# Copyright 2025 TeleAI Rhodes Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main VLA model architecture based on Qwen3-VL."""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.modeling_outputs import ModelOutput
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, is_torchdynamo_compiling

from .modeling_qwen3_vl import (
    Qwen3VLForConditionalGeneration,
    Qwen3VLTextModel,
    Qwen3VLVisionModel,
)
from .configuration_prts_qwen3_vl import PRTS_FlowMatchingConfig_Qwen3VL
from .dit_action_head import FlowMatchingDiTHead, MoTFlowMatchingHead

ACTION_DATASET_NAMES = []

# ----------------------------- Print Customization -----------------------------
from colorama import init, Fore, Style
from datetime import datetime

# Initialize colorama
init(autoreset=True)

class CustomPrinter:
    """Custom colored printer."""
    
    # Define message type configuration
    TYPE_CONFIG = {
        'normal': {
            'color': Fore.WHITE,
            'icon': '',
            'prefix': '',
            'style': Style.NORMAL
        },
        'important': {
            'color': Fore.CYAN,
            'icon': '💡',
            'prefix': 'IMPORTANT',
            'style': Style.BRIGHT
        }
    }

    @classmethod
    def print(cls, message, msg_type='normal', show_time=True, show_icon=True, end='\n'):
        """
        Custom print function.
        
        Args:
            message: The message content to print
            msg_type: Message type ('normal' or 'important'; unrecognized
                types fall back to 'normal')
            show_time: Whether to display a timestamp
            show_icon: Whether to display the icon
            end: Line terminator
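
        Example (a sketch; output is colorized and the timestamp varies):
            CustomPrinter.print("loading weights", msg_type='important')
            # -> [12:00:00] 💡 IMPORTANT: loading weights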
        """
        # Get configuration for the message type
        config = cls.TYPE_CONFIG.get(msg_type, cls.TYPE_CONFIG['normal'])
        
        # Build prefix parts
        prefix_parts = []
        
        # Add timestamp
        if show_time:
            timestamp = datetime.now().strftime('%H:%M:%S')
            prefix_parts.append(f"[{timestamp}]")
        
        # Add icon and prefix text
        icon_text = f"{config['icon']} " if show_icon else ""
        prefix_parts.append(f"{icon_text}{config['prefix']}")
        
        if config['prefix'] == '':
            full_message = message
        else:
            # Combine prefix parts
            prefix = " ".join(prefix_parts)
            
            # Construct full message
            full_message = f"{prefix}: {message}"
        
        # Apply color and style and print
        formatted_message = f"{config['style']}{config['color']}{full_message}"
        print(formatted_message, end=end)
    
    @classmethod
    def normal(cls, message, **kwargs):
        """Convenience: normal-level print."""
        cls.print(message, 'normal', **kwargs)

    @classmethod
    def important(cls, message, **kwargs):
        """Convenience: important-level print."""
        cls.print(message, 'important', **kwargs)

def important(message, **kwargs):
    CustomPrinter.important(message, **kwargs)

# -------------------------------------------------------------


def create_sinusoidal_pos_embedding(
    time: torch.Tensor,
    dimension: int,
    min_period: float = 4e-3,
    max_period: float = 4.0,
    device="cpu",
) -> torch.Tensor:
    """
    Computes sine-cosine positional embedding vectors for scalar positions (diffusion timesteps).
    
    Args:
        time: Tensor of shape (batch_size,) containing timestep values
        dimension: Embedding dimension (must be even)
        min_period: Minimum period for sinusoidal encoding
        max_period: Maximum period for sinusoidal encoding
        device: Device to create tensors on
        
    Returns:
        Positional embeddings of shape (batch_size, dimension)
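
    Example (illustrative; the timestep values are arbitrary):
        >>> t = torch.tensor([0.25, 0.5])
        >>> emb = create_sinusoidal_pos_embedding(t, dimension=8)
        >>> emb.shape
        torch.Size([2, 8])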
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
    period = min_period * (max_period / min_period) ** fraction

    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb


class ContrastiveEncoder(nn.Module):
    """
    MLP projector for Contrastive Reinforcement Learning (CRL) embeddings.
    
    Projects hidden states to a shared latent space for contrastive learning,
    with optional L2 normalization (``repr_norm``) for stable similarity computation.
    
    Architecture: N-layer MLP with LayerNorm and Swish activation, 
                  followed by a cold-initialized output projection.
                  [Linear -> LayerNorm -> Swish] x N -> Linear (cold init)

    Matches stable_contrastive_rl's Q network structure (default: 4 hidden layers).
    
    Args:
        input_dim: Dimension of input hidden states
        output_dim: Dimension of output embeddings (default: 256)
        hidden_dim: Dimension of hidden layers (default: 1024)
        num_layers: Number of hidden layers (default: 4)
        repr_norm: Whether to L2-normalize outputs (default: False)
        init_w: Small value for last layer weight initialization for cold init (default: 1e-12)
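
    Example (illustrative; dimensions are arbitrary):
        >>> enc = ContrastiveEncoder(input_dim=2048, output_dim=256)
        >>> z = enc(torch.randn(4, 2048))
        >>> z.shape
        torch.Size([4, 256])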
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: int = 256,
        hidden_dim: int = 1024,
        num_layers: int = 4,
        repr_norm: bool = False,
        init_w: float = 1e-12,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.repr_norm = repr_norm
        
        # Build hidden layers with LayerNorm
        self.hidden_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            self.hidden_layers.append(nn.Linear(in_dim, hidden_dim))
            self.layer_norms.append(nn.LayerNorm(hidden_dim))
        
        # Output projection layer with cold initialization
        self.output_proj = nn.Linear(hidden_dim, output_dim)
        self.output_proj.weight.data.uniform_(-init_w, init_w)
        self.output_proj.bias.data.fill_(0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project input to the contrastive embedding space.
        
        Args:
            x: Input tensor of shape (batch_size, input_dim)
            
        Returns:
            Embeddings of shape (batch_size, output_dim),
            L2-normalized when ``repr_norm`` is True
        """
        # Pass through hidden layers
        for fc, norm in zip(self.hidden_layers, self.layer_norms):
            x = fc(x)
            x = norm(x)
            x = F.silu(x)
        
        # Output projection
        x = self.output_proj(x)
        
        # Optional L2 normalization
        if self.repr_norm:
            x = F.normalize(x, dim=-1)
        
        return x



@dataclass
class PRTS_Qwen3VL_ModelOutputWithPast(ModelOutput):
    """
    Output class for PRTS model based on Qwen3-VL.
    
    Args:
        loss: Combined total loss
        flow_loss: Flow matching loss for action prediction
        cross_entropy_loss: Standard language modeling loss
        crl_loss: Contrastive Reinforcement Learning loss for goal-action alignment
        logits: Language model logits
        past_key_values: Cached key-value states
        hidden_states: Hidden states from all layers (if output_hidden_states=True)
        attentions: Attention weights (if output_attentions=True)
        rope_deltas: RoPE position delta information
        channel_loss_dict: Per-dataset loss values for logging
        channel_loss_count_dict: Per-dataset token counts for loss normalization
    """
    loss: Optional[torch.FloatTensor] = None
    flow_loss: Optional[torch.FloatTensor] = None
    cross_entropy_loss: Optional[torch.FloatTensor] = None
    crl_loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None

    crl_num_samples: Optional[torch.LongTensor] = None
    channel_loss_dict: Optional[dict] = None
    channel_loss_count_dict: Optional[dict] = None


class PRTS_Qwen3VL(Qwen3VLForConditionalGeneration):
    """
    Vision-Language-Action model based on Qwen3-VL.
    
    This model extends Qwen3-VL to support:
    1. Proprioceptive state embedding and prediction
    2. Sub-task description generation (language format)
    3. Action chunk prediction via flow matching (continuous actions)
    4. Optional discrete action tokenization (fast mode)
    
    The model uses a flow matching approach for continuous action prediction, with a DiT
    (Diffusion Transformer) action head that cross-attends to VLM hidden states.
    """
    config: PRTS_FlowMatchingConfig_Qwen3VL

    _tied_weights_keys = ["lm_head.weight"]
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]

    def __init__(
        self,
        config: PRTS_FlowMatchingConfig_Qwen3VL,
    ):
        """
        Initialize the PRTS Qwen3-VL model for action processing.

        Args:
            config: Model configuration; carries the action-related settings
                used below (use_fast_action_tokenizer,
                flow_matching_action_loss_weight, crl_loss_weight, and the
                DiT action-head config)
        """
        super().__init__(config)

        # The parent class initializes:
        # - self.visual: Qwen3VLVisionModel
        # - self.language_model: Qwen3VLTextModel
        # - self.lm_head: Language model head
        # - self.rope_deltas: Cached rope deltas
        # We keep these and add PRTS-specific components

        # PRTS-specific parameters
        self.action_dim = config.max_action_dim
        self.use_fast_tokenizer = config.use_fast_action_tokenizer
        self.flow_matching_action_loss_weight = config.flow_matching_action_loss_weight

        # Loss functions
        self.loss_fct = CrossEntropyLoss(reduction="none")
        self.loss_mse = MSELoss(reduction="none")

        # DiT-based flow matching action head: standard (+ AlternateVLDiT) or pi0.5 KV expert
        self.use_mot_action_expert = config.dit_action_head_config.get(
            "use_mot_action_expert", False
        )
        if config.flow_matching_action_loss_weight > 0.:
            if self.use_mot_action_expert:
                self.dit_action_head = MoTFlowMatchingHead(
                    action_dim=self.action_dim,
                    action_chunk_size=config.action_chunk_size,
                    vlm_config=config.text_config,
                    num_inference_timesteps=config.num_denoise_steps,
                    config=config.dit_action_head_config,
                )
            else:
                self.dit_action_head = FlowMatchingDiTHead(
                    action_dim=self.action_dim,
                    action_chunk_size=config.action_chunk_size,
                    cross_attention_dim=config.text_config.hidden_size,
                    num_inference_timesteps=config.num_denoise_steps,
                    config=config.dit_action_head_config,
                )

        # CRL (Contrastive Reinforcement Learning) components
        if config.crl_loss_weight > 0.:
            hidden_size = config.text_config.hidden_size
            # Current encoders (trainable)
            self.crl_action_encoder = ContrastiveEncoder(
                input_dim=hidden_size,
                output_dim=config.crl_embed_dim,
                init_w=config.crl_encoder_init_w,
                repr_norm=config.crl_repr_norm,
            )
            self.crl_goal_encoder = ContrastiveEncoder(
                input_dim=hidden_size,
                output_dim=config.crl_embed_dim,
                init_w=config.crl_encoder_init_w,
                repr_norm=config.crl_repr_norm,
            )
            # Learnable temperature (log-space for numerical stability, CLIP recipe).
            self.crl_logit_scale = nn.Parameter(
                torch.full([], math.log(1 / 0.2))
            )
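            # At loss time the temperature is typically applied CLIP-style
            # (a sketch only; the actual CRL loss is computed outside this module,
            # and h_action / h_goal are hypothetical pooled hidden states):
            #   a = self.crl_action_encoder(h_action)            # (B, crl_embed_dim)
            #   g = self.crl_goal_encoder(h_goal)                # (B, crl_embed_dim)
            #   logits = self.crl_logit_scale.exp() * a @ g.t()  # (B, B)
            #   loss = F.cross_entropy(logits, torch.arange(a.shape[0], device=a.device))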

        # Initialize weights
        self.post_init()

        # Print parameter counts
        visual_params = sum(p.numel() for p in self.visual.parameters())
        language_params = sum(p.numel() for p in self.language_model.parameters())
        model_params = visual_params + language_params
        important(f"Backbone VLM (visual + language_model) parameters: {model_params / 1e6:.2f}M")
        important(f"Flow Matching Loss coefficient: {self.flow_matching_action_loss_weight}")

        if config.flow_matching_action_loss_weight > 0.:
            dit_params = sum(p.numel() for p in self.dit_action_head.parameters())
            # Get the inner model type name for logging
            if hasattr(self.dit_action_head, 'dit'):
                dit_head_type = type(self.dit_action_head.dit).__name__
            else:
                dit_head_type = type(self.dit_action_head).__name__
            important(f"DiT Action Head ({dit_head_type}) parameters: {dit_params / 1e6:.2f}M")

        if config.crl_loss_weight > 0.:
            crl_params = sum(p.numel() for p in self.crl_action_encoder.parameters())
            crl_params += sum(p.numel() for p in self.crl_goal_encoder.parameters())
            important(f"CRL Encoders (action + goal) parameters: {crl_params / 1e6:.2f}M")
            important(f"CRL Loss coefficient: {config.crl_loss_weight}")
            important(f"CRL Encoder init_w: {config.crl_encoder_init_w}")
            important(f"CRL Repr Norm: {config.crl_repr_norm}")

        self.fast_action_token_start_idx = 200000
        self.use_multi_positive = True

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)
    
    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_output_embeddings(self):
        return self.lm_head
    
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def to_float32_flow_matching_head(self):
        """Convert flow matching heads to float32 for numerical stability."""
        if hasattr(self, 'dit_action_head'):
            self.dit_action_head = self.dit_action_head.to(dtype=torch.float32)

    def set_fast_action_info(self, action_mapper, fast_action_token_start_idx):
        """Set information for fast (discrete) action tokenization."""
        self.action_mapper = action_mapper
        self.fast_action_token_start_idx = fast_action_token_start_idx
    
    def get_placeholder_mask_with_special_token(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        special_features: torch.FloatTensor,
        special_pad_token_id: int,
    ):
        """
        Get placeholder mask for a specific special token (e.g., state tokens).
        
        Similar to get_placeholder_mask but for custom special tokens beyond image/video.
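
        Typical usage (a sketch; ``state_features`` and ``state_token_id`` are
        hypothetical names for the caller's tensors):

            mask = self.get_placeholder_mask_with_special_token(
                input_ids, inputs_embeds, state_features, state_token_id)
            inputs_embeds = inputs_embeds.masked_scatter(mask, state_features)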
        """
        if input_ids is None:
            special_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(special_pad_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_mask = special_mask.all(-1)
        else:
            special_mask = input_ids == special_pad_token_id

        n_special_tokens = special_mask.sum()
        special_mask = special_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if special_features is not None and inputs_embeds[special_mask].numel() != special_features.numel():
            raise ValueError(
                f"Features and tokens do not match: tokens: {n_special_tokens}, features {special_features.shape[0]}"
            )

        return special_mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        # use_cache: Optional[bool] = None,
        # output_attentions: Optional[bool] = None,
        # output_hidden_states: Optional[bool] = None,
        # return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        # rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        actions: Optional[torch.Tensor] = None,
        action_is_pad: torch.Tensor | None = None,
        action_dof_mask: Optional[torch.Tensor] = None,
        dataset_names: Optional[List[str]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, PRTS_Qwen3VL_ModelOutputWithPast]:
        """
        Forward pass for PRTS_Qwen3VL model.
        
        This extends Qwen3VLForConditionalGeneration.forward with:
        - State embedding injection
        - Action chunk flow matching
        - DeepStack visual feature handling
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")


        # 1. Prepare input embeddings
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_mask = None
        video_mask = None

        # 2. Process images with deepstack features
        deepstack_image_embeds = None
        if pixel_values is not None:
            image_embeds, deepstack_image_embeds = self.get_image_features(
                pixel_values, image_grid_thw,
                image_max_seqlen=kwargs.get('image_max_seqlen'),
            )
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        # 3. Process videos with deepstack features
        deepstack_video_embeds = None
        if pixel_values_videos is not None:
            video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        # 4. Aggregate deepstack visual features
        visual_pos_masks = None
        deepstack_visual_embeds = None
        if image_mask is not None and video_mask is not None:
            # aggregate visual_pos_masks and deepstack_visual_embeds
            image_mask = image_mask[..., 0]
            video_mask = video_mask[..., 0]
            visual_pos_masks = image_mask | video_mask
            deepstack_visual_embeds = []
            image_mask_joint = image_mask[visual_pos_masks]
            video_mask_joint = video_mask[visual_pos_masks]
            for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
                embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
                embed_joint[image_mask_joint, :] = img_embed
                embed_joint[video_mask_joint, :] = vid_embed
                deepstack_visual_embeds.append(embed_joint)
        elif image_mask is not None:
            image_mask = image_mask[..., 0]
            visual_pos_masks = image_mask
            deepstack_visual_embeds = deepstack_image_embeds
        elif video_mask is not None:
            video_mask = video_mask[..., 0]
            visual_pos_masks = video_mask
            deepstack_visual_embeds = deepstack_video_embeds

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        # 5. Compute position IDs using Qwen3VL's rope index
        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                if attention_mask_tensor.dtype.is_floating_point:
                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
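                # position_ids has shape (3, B, L): one row per M-RoPE axis
                # (temporal / height / width), matching get_rope_index's output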

        _lm_extra_kwargs: dict = {}

        _use_cache = (
            self.use_mot_action_expert
            and self.flow_matching_action_loss_weight > 0.
            and actions is not None
        )

        # 6. Run the language model backbone
        vlm_outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=_use_cache,
            cache_position=cache_position,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            output_hidden_states=False,
            **_lm_extra_kwargs,
            **kwargs,
        )

        vlm_hidden_states = vlm_outputs.last_hidden_state

        # 7. Run DiT action head if actions are present
        dit_pred_v = None
        dit_velocity = None
        if actions is not None and self.flow_matching_action_loss_weight > 0:
            # vlm_hidden_states shape: bs, seq_length, hidden_size
            actions_for_dit = actions.to(vlm_hidden_states.device, dtype=vlm_hidden_states.dtype)
            dof_mask_for_dit = action_dof_mask.to(vlm_hidden_states.device, dtype=vlm_hidden_states.dtype) if action_dof_mask is not None else None
            # Pass attention_mask so DiT cross-attention ignores padding tokens
            dit_encoder_attention_mask = attention_mask.bool() if attention_mask is not None else None

            if self.use_mot_action_expert and vlm_outputs.past_key_values is not None:
                dit_pred_v, dit_velocity = self.dit_action_head(
                    vlm_outputs.past_key_values,
                    actions_for_dit,
                    dof_mask_for_dit,
                    encoder_attention_mask=dit_encoder_attention_mask,
                )
            else:
                # Standard: pass single (last-layer) VLM hidden states
                dit_image_mask = visual_pos_masks.bool() if visual_pos_masks is not None else None
                dit_pred_v, dit_velocity = self.dit_action_head(
                    vlm_hidden_states, actions_for_dit, dof_mask_for_dit,
                    encoder_attention_mask=dit_encoder_attention_mask,
                    image_mask=dit_image_mask,
                )

        # 8. Compute logits
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(vlm_hidden_states[:, slice_indices, :])
        
        # 9. Compute losses
        loss = None
        cross_entropy_loss, flow_loss = None, None
        channel_loss_dict = None
        channel_loss_count_dict = None

        if labels is not None:
            loss = 0
            action_accuracy = 0

            # Compute cross-entropy loss
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            non_ignored_mask = shift_labels != -100
            _cross_entropy_loss = self.loss_fct(shift_logits, shift_labels)
            cross_entropy_loss = (
                _cross_entropy_loss[non_ignored_mask].mean()
                if non_ignored_mask.any()
                else (_cross_entropy_loss.sum() * 0.0)
            )

            # Add cross-entropy loss to the total; skip NaN batches so they
            # do not poison the gradient
            if not torch.isnan(cross_entropy_loss):
                loss += cross_entropy_loss

            # Compute action token prediction accuracy (for logging)
            shift_logits_for_acc = logits[..., :-1, :].contiguous()
            action_preds = shift_logits_for_acc.argmax(dim=-1)
            shift_labels_for_acc = labels[..., 1:].contiguous()

            action_mask = (
                shift_labels_for_acc >= self.fast_action_token_start_idx
            )

            if self.use_fast_tokenizer and action_mask.any():
                correct_preds = (action_preds == shift_labels_for_acc) & action_mask
                action_accuracy = (
                    correct_preds.sum().float() / action_mask.sum().float()
                )

                if channel_loss_dict is None:
                    channel_loss_dict = {}
                    channel_loss_count_dict = {}
                
                channel_loss_dict["action_accuracy"] = action_accuracy.detach()
                channel_loss_count_dict["action_accuracy"] = torch.tensor(1, device=action_accuracy.device)

        # 10. Compute flow matching loss (DiT action head)
        if dit_pred_v is not None and self.flow_matching_action_loss_weight > 0:
            if channel_loss_dict is not None:
                channel_loss_dict.update(
                    {
                        f"flow_matching/{dataset_name}": torch.tensor(0.0, device=logits.device)
                        for dataset_name in ACTION_DATASET_NAMES
                    }
                )
                channel_loss_count_dict.update(
                    {
                        f"flow_matching/{dataset_name}": torch.tensor(0, device=logits.device)
                        for dataset_name in ACTION_DATASET_NAMES
                    }
                )
            else:
                channel_loss_dict = {
                    f"flow_matching/{dataset_name}": torch.tensor(0.0, device=logits.device)
                    for dataset_name in ACTION_DATASET_NAMES
                }
                channel_loss_count_dict = {
                    f"flow_matching/{dataset_name}": torch.tensor(0, device=logits.device)
                    for dataset_name in ACTION_DATASET_NAMES
                }

            # Compute flow matching loss: MSE between predicted and target velocity
            _fm_loss = self.loss_mse(dit_pred_v, dit_velocity)

            # Apply DOF mask (zero out invalid action dimensions)
            if action_dof_mask is not None:
                valid_action_dim = int(action_dof_mask[0, 0, :].sum(dim=-1).item())     # NOTE: only supports fine-tuning on a single embodiment's data
                _fm_loss = _fm_loss[:, :, :valid_action_dim]

            # Apply action_is_pad mask: exclude padding timesteps from loss
            # action_is_pad: (B, T), True = pad timestep → should not contribute to loss
            if action_is_pad is not None:
                valid_timestep_mask = ~action_is_pad[:, :_fm_loss.shape[1]]  # align length
                _fm_loss = _fm_loss * valid_timestep_mask.unsqueeze(-1)
                # clamp guards against 0/0 = NaN when every timestep in the batch is padding
                denom = (valid_timestep_mask.sum() * _fm_loss.shape[-1]).clamp(min=1)
                flow_loss = _fm_loss.sum() / denom
            else:
                flow_loss = _fm_loss.mean()
            
            # Add the weighted flow matching loss to the total; skip NaN batches
            if not torch.isnan(flow_loss):
                weighted_fm_loss = self.flow_matching_action_loss_weight * flow_loss
                loss = loss + weighted_fm_loss if loss is not None else weighted_fm_loss

            # Per-dataset flow matching loss logging
            logging_fm_loss = _fm_loss.detach().mean(dim=(1, 2))  # Mean over chunk_size and action_dim

            action_dataset_names = dataset_names if dataset_names is not None else []
            unique_action_datasets = list(set(action_dataset_names))
            
            for dataset_name_i in unique_action_datasets:
                action_dataset_mask = torch.tensor(
                    [name == dataset_name_i for name in action_dataset_names],
                    device=logits.device,
                )
                if action_dataset_mask.any():
                    dataset_fm_loss = logging_fm_loss[action_dataset_mask].sum()
                    dataset_fm_count = action_dataset_mask.sum()
                    
                    prefixed_key = f"flow_matching/{dataset_name_i}"
                    channel_loss_dict[prefixed_key] += dataset_fm_loss
                    channel_loss_count_dict[prefixed_key] += dataset_fm_count
            
        elif self.flow_matching_action_loss_weight > 0:
            # Dummy zero loss keeps every DiT parameter in the computation graph
            # (e.g. so distributed training does not flag unused parameters on
            # batches that carry no actions)
            dummy_params = [p.sum() * 0.0 for p in self.dit_action_head.parameters() if p.requires_grad]
            dummy_loss = sum(dummy_params) if len(dummy_params) > 0 else torch.tensor(0.0, device=logits.device)
            loss = (loss + dummy_loss) if loss is not None else dummy_loss

        return PRTS_Qwen3VL_ModelOutputWithPast(
            loss=loss,
            cross_entropy_loss=(
                cross_entropy_loss.detach() if cross_entropy_loss is not None else None
            ),
            flow_loss=(
                flow_loss.detach() if flow_loss is not None else None
            ),
            crl_loss=None,
            logits=logits,
            past_key_values=vlm_outputs.past_key_values,
            # hidden_states=vlm_outputs.hidden_states,
            # attentions=vlm_outputs.attentions,
            crl_num_samples=None,
            rope_deltas=self.rope_deltas,
            channel_loss_dict=channel_loss_dict,
            channel_loss_count_dict=channel_loss_count_dict,
        )
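
    # Illustrative training step (a sketch; the batch keys below are hypothetical
    # and depend on the surrounding data collator):
    #
    #   out = model(
    #       input_ids=batch["input_ids"],
    #       attention_mask=batch["attention_mask"],
    #       pixel_values=batch["pixel_values"],
    #       image_grid_thw=batch["image_grid_thw"],
    #       labels=batch["labels"],
    #       actions=batch["actions"],              # (B, chunk_size, action_dim)
    #       action_is_pad=batch["action_is_pad"],  # (B, chunk_size) bool
    #   )
    #   out.loss.backward()   # cross-entropy + weighted flow-matching loss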
    

    def embed_prefix(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
        """
        Embed prefix tokens including vision, DeepStack, and (optionally) state features.

        Returns:
            (inputs_embeds, visual_pos_masks, deepstack_visual_embeds)
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_mask = None
        video_mask = None
        deepstack_image_embeds = None
        deepstack_video_embeds = None

        if pixel_values is not None:
            image_embeds, deepstack_image_embeds = self.get_image_features(
                pixel_values, image_grid_thw,
                image_max_seqlen=kwargs.get('image_max_seqlen'),
            )
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        visual_pos_masks = None
        deepstack_visual_embeds = None
        if image_mask is not None and video_mask is not None:
            image_mask = image_mask[..., 0]
            video_mask = video_mask[..., 0]
            visual_pos_masks = image_mask | video_mask
            deepstack_visual_embeds = []
            image_mask_joint = image_mask[visual_pos_masks]
            video_mask_joint = video_mask[visual_pos_masks]
            for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
                embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
                embed_joint[image_mask_joint, :] = img_embed
                embed_joint[video_mask_joint, :] = vid_embed
                deepstack_visual_embeds.append(embed_joint)
        elif image_mask is not None:
            image_mask = image_mask[..., 0]
            visual_pos_masks = image_mask
            deepstack_visual_embeds = deepstack_image_embeds
        elif video_mask is not None:
            video_mask = video_mask[..., 0]
            visual_pos_masks = video_mask
            deepstack_visual_embeds = deepstack_video_embeds

        return inputs_embeds, visual_pos_masks, deepstack_visual_embeds

    @torch.no_grad()
    def sample_actions(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        action_dof_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Any]:
        """
        Sample actions using DiT-based flow matching denoising.

        1. Computes position_ids via get_rope_index
        2. Embeds the prefix (with DeepStack visual features)
        3. Runs the language model to get hidden states
        4. Uses DiT action head to denoise actions via cross-attention to VLM features

        Returns:
            (x_t, outputs) — denoised action trajectories and language-model outputs
        """
        if position_ids is None:
            position_ids, _ = self.get_rope_index(
                input_ids,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                attention_mask=attention_mask,
            )

        visual_pos_masks = None
        deepstack_visual_embeds = None
        if inputs_embeds is None:
            inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self.embed_prefix(
                input_ids,
                pixel_values=pixel_values,
                pixel_values_videos=pixel_values_videos,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                **kwargs,
            )

        _sample_use_cache = (
            self.use_mot_action_expert and self.flow_matching_action_loss_weight > 0
        )
        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=_sample_use_cache,
            cache_position=cache_position,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            output_hidden_states=False,
        )

        vlm_hidden_states = outputs.last_hidden_state
        dit_encoder_attention_mask = attention_mask.bool() if attention_mask is not None else None

        if self.use_mot_action_expert and outputs.past_key_values is not None:
            x_t = self.dit_action_head.predict_action(
                outputs.past_key_values,
                action_dof_mask,
                encoder_attention_mask=dit_encoder_attention_mask,
            )
        else:
            dit_image_mask = visual_pos_masks.bool() if visual_pos_masks is not None else None
            x_t = self.dit_action_head.predict_action(
                vlm_hidden_states, action_dof_mask,
                encoder_attention_mask=dit_encoder_attention_mask,
                image_mask=dit_image_mask,
            )

        return x_t, outputs
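
    # Illustrative inference call (a sketch; `enc` stands for the processor's
    # output, and preprocessing is omitted):
    #
    #   actions, _ = model.sample_actions(
    #       input_ids=enc["input_ids"],
    #       attention_mask=enc["attention_mask"],
    #       pixel_values=enc["pixel_values"],
    #       image_grid_thw=enc["image_grid_thw"],
    #   )
    #   # actions: (B, action_chunk_size, max_action_dim) denoised trajectory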


PRTS_Qwen3VL.register_for_auto_class()


__all__ = ["PRTS_Qwen3VL", "PRTS_Qwen3VL_ModelOutputWithPast"]