| """ |
| DiT (Diffusion Transformer) based flow matching action head for PRTS. |
| |
| Replaces the Qwen3VLTextModel-based fm_action_expert with a lightweight DiT |
| that uses explicit cross-attention to VLM hidden states, following the architecture |
| from GR00T / pi05. |
| |
| Architecture: |
| ActionEncoder(noisy_actions + dof_mask, timestep) |
| → action_features |
| → DiT(cross-attn to VLM hidden states, ada-norm timestep conditioning) |
| → ActionDecoder → predicted velocity |
| """ |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.distributions import Beta |
| from typing import Optional |
|
|
| from transformers.cache_utils import Cache |
| from transformers.modeling_flash_attention_utils import _flash_attention_forward |
|
|
|
|
| |
| |
| |
| |
|
|
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """Classic transformer-style sinusoidal encoding for timesteps/positions.

    Emits ``[sin(freqs), cos(freqs)]`` along the last axis, with frequencies
    decaying geometrically from 1 down to ~1/10000 over ``embedding_dim // 2``
    channels.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Encode ``(B,)`` or ``(B, T)`` timesteps to ``(..., embedding_dim)``."""
        timesteps = timesteps.float()
        was_1d = timesteps.dim() == 1
        if was_1d:
            timesteps = timesteps[:, None]

        half = self.embedding_dim // 2
        # Geometric frequency ladder: exp(-k * ln(10000) / half), k = 0..half-1.
        inv_freq = torch.exp(
            torch.arange(half, dtype=torch.float, device=timesteps.device)
            * (-math.log(10000.0) / half)
        )
        angles = timesteps[..., None] * inv_freq
        encoded = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # Drop the temporary sequence axis for 1-D (per-sample) inputs.
        return encoded[:, 0] if was_1d else encoded
|
|
|
|
class TimestepEncoder(nn.Module):
    """Maps scalar timesteps into the model's embedding space.

    Pipeline: fixed 256-dim sinusoidal features → Linear → SiLU → Linear.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.sinusoidal = SinusoidalPositionalEncoding(256)
        self.linear_1 = nn.Linear(256, embedding_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return ``(B, embedding_dim)`` embeddings for ``(B,)`` timesteps."""
        sin_feats = self.sinusoidal(timesteps)
        # Sinusoidal features come back in float32; cast to the MLP's dtype.
        hidden = self.linear_1(sin_feats.to(dtype=self.linear_1.weight.dtype))
        return self.linear_2(self.act(hidden))
|
|
|
|
class AdaLayerNorm(nn.Module):
    """LayerNorm modulated by a conditioning vector (timestep embedding).

    The conditioning embedding is passed through SiLU and projected to a
    (scale, shift) pair, applied as ``norm(x) * (1 + scale) + shift``.
    """

    def __init__(self, embedding_dim: int, eps: float = 1e-5):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=False)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """Modulate ``(B, T, D)`` features with a ``(B, D)`` conditioning vector."""
        modulation = self.linear(self.silu(temb))
        scale, shift = modulation.chunk(2, dim=-1)
        # Broadcast per-sample (B, D) modulation over the sequence axis.
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
|
|
class DiTAttention(nn.Module):
    """Multi-head attention usable for both self- and cross-attention.

    Two backends, selected by ``attn_implementation``:

    * ``"sdpa"`` (default) — :func:`F.scaled_dot_product_attention`; a 2-D
      padding mask ``(B, S)`` is broadcast to ``(B, 1, 1, S)``.
    * ``"flash_attention_2"`` — delegates to HuggingFace's
      ``_flash_attention_forward`` wrapper around the ``flash_attn`` kernels,
      which handles unpadding internally when a padding mask is supplied.
    """

    def __init__(
        self,
        query_dim: int,
        num_heads: int,
        head_dim: int,
        cross_attention_dim: Optional[int] = None,
        dropout: float = 0.0,
        bias: bool = True,
        attn_implementation: str = "sdpa",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.attn_implementation = attn_implementation
        inner_dim = num_heads * head_dim

        # K/V project from the context width when cross-attending,
        # otherwise from the query width (self-attention).
        kv_dim = query_dim if cross_attention_dim is None else cross_attention_dim
        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(kv_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(kv_dim, inner_dim, bias=bias)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=bias),
            nn.Dropout(dropout),
        )

    def _flash_attn_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run Flash Attention via HuggingFace's ``_flash_attention_forward``.

        Args:
            q: ``(B, T_q, H, D)``
            k: ``(B, T_k, H, D)``
            v: ``(B, T_k, H, D)``
            attention_mask: ``(B, T_k)`` bool, True = valid token.

        Returns:
            ``(B, T_q, H*D)``
        """
        batch, q_len, heads, dim = q.shape
        attn = _flash_attention_forward(
            q, k, v,
            attention_mask=attention_mask,
            query_length=q_len,
            is_causal=False,
            dropout=0.0,
        )
        return attn.reshape(batch, q_len, heads * dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend queries to themselves or to ``encoder_hidden_states``.

        Args:
            hidden_states: ``(B, T, query_dim)`` queries.
            encoder_hidden_states: optional ``(B, S, kv_dim)`` context;
                self-attention over ``hidden_states`` when omitted.
            attention_mask: optional ``(B, S)`` bool padding mask
                (True = keep), or an already-broadcast mask.
        """
        batch, q_len, _ = hidden_states.shape
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(context)
        v = self.to_v(context)

        if self.attn_implementation == "flash_attention_2":
            # Flash attention consumes (B, T, H, D) directly.
            q = q.view(batch, q_len, self.num_heads, self.head_dim)
            k = k.view(batch, -1, self.num_heads, self.head_dim)
            v = v.view(batch, -1, self.num_heads, self.head_dim)
            attn_out = self._flash_attn_forward(q, k, v, attention_mask)
            return self.to_out(attn_out)

        # SDPA consumes (B, H, T, D).
        q = q.view(batch, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

        sdpa_mask = attention_mask
        if sdpa_mask is not None and sdpa_mask.dim() == 2:
            # Expand padding mask (B, S) -> (B, 1, 1, S) for broadcasting.
            sdpa_mask = sdpa_mask[:, None, None, :]

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=sdpa_mask, dropout_p=0.0
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, q_len, -1)
        return self.to_out(attn_out)
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: Linear → tanh-GELU → Dropout → Linear → Dropout."""

    def __init__(self, dim: int, dropout: float = 0.0, mult: int = 4):
        super().__init__()
        # Hidden width is `mult` times the model width.
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``x``."""
        return self.net(x)
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: (ada-)norm → attention → residual → norm → FFN → residual.

    With ``cross_attention_dim`` set, ``attn1`` cross-attends to
    ``encoder_hidden_states``; otherwise it self-attends.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        norm_type: str = "ada_norm",
        final_dropout: bool = False,
        attn_implementation: str = "sdpa",
    ):
        super().__init__()
        self.norm_type = norm_type

        # AdaLayerNorm consumes a timestep embedding; plain LayerNorm does not.
        self.norm1 = AdaLayerNorm(dim) if norm_type == "ada_norm" else nn.LayerNorm(dim)

        self.attn1 = DiTAttention(
            query_dim=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
            cross_attention_dim=cross_attention_dim,
            dropout=dropout,
            attn_implementation=attn_implementation,
        )

        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout)
        self.final_dropout = nn.Dropout(dropout) if final_dropout else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """One block step; ``temb`` is required when ``norm_type == "ada_norm"``."""
        if self.norm_type == "ada_norm":
            normed = self.norm1(hidden_states, temb)
        else:
            normed = self.norm1(hidden_states)

        attn_out = self.attn1(
            normed,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        if self.final_dropout is not None:
            attn_out = self.final_dropout(attn_out)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with plain LayerNorm and residual.
        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
        return hidden_states
|
|
|
|
class DiT(nn.Module):
    """Diffusion Transformer cross-attending to VLM context features.

    With ``interleave_self_attention=True``, odd-indexed blocks self-attend
    while even-indexed blocks cross-attend to ``encoder_hidden_states``.
    Timestep conditioning flows through AdaLayerNorm in every block, and the
    output head applies one more timestep-conditioned scale-shift before the
    final projection.
    """

    def __init__(
        self,
        num_attention_heads: int = 12,
        attention_head_dim: int = 64,
        output_dim: int = 768,
        num_layers: int = 12,
        dropout: float = 0.1,
        norm_type: str = "ada_norm",
        final_dropout: bool = True,
        interleave_self_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
        attn_implementation: str = "sdpa",
    ):
        super().__init__()
        self.inner_dim = num_attention_heads * attention_head_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.interleave_self_attention = interleave_self_attention

        self.timestep_encoder = TimestepEncoder(self.inner_dim)

        blocks = []
        for layer_idx in range(num_layers):
            # Odd layers become self-attention blocks when interleaving is on.
            is_self_attn = interleave_self_attention and layer_idx % 2 == 1
            blocks.append(
                BasicTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=None if is_self_attn else cross_attention_dim,
                    norm_type=norm_type,
                    final_dropout=final_dropout,
                    attn_implementation=attn_implementation,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block stack and project to ``(B, T, output_dim)``."""
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        for layer_idx, block in enumerate(self.transformer_blocks):
            # Self-attention layers receive no encoder context or mask.
            is_self_attn = self.interleave_self_attention and layer_idx % 2 == 1
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=None if is_self_attn else encoder_hidden_states,
                encoder_attention_mask=None if is_self_attn else encoder_attention_mask,
                temb=temb,
            )

        # Final adaptive scale-shift, then project to the output width.
        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
|
|
|
|
class AlternateVLDiT(DiT):
    """DiT variant that splits the encoder context into visual and text tokens.

    Even-indexed blocks cross-attend, alternating every
    ``attend_text_every_n_blocks`` cross-attention layers between the text-token
    mask and the visual-token mask. Odd-indexed blocks self-attend, so
    ``interleave_self_attention=True`` is mandatory. When ``image_mask`` is
    missing or all-False, every valid token is treated as text.
    """

    def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.interleave_self_attention, (
            "AlternateVLDiT requires interleave_self_attention=True"
        )
        self.attend_text_every_n_blocks = attend_text_every_n_blocks

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        image_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            encoder_attention_mask: (B, S) bool – True = valid VLM token.
            image_mask: (B, S) bool – True = visual token position.
                If None, all valid tokens are treated as text.
        """
        temb = self.timestep_encoder(timestep)
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        batch, ctx_len, _ = encoder_hidden_states.shape
        if encoder_attention_mask is not None:
            valid_mask = encoder_attention_mask.bool()
        else:
            valid_mask = torch.ones(
                batch, ctx_len, dtype=torch.bool, device=hidden_states.device
            )

        if image_mask is not None and image_mask.any():
            vis_mask = image_mask.bool() & valid_mask
            text_mask = (~image_mask.bool()) & valid_mask
        else:
            # No visual tokens anywhere: everything valid counts as text.
            vis_mask = torch.zeros_like(valid_mask)
            text_mask = valid_mask

        for layer_idx, block in enumerate(self.transformer_blocks):
            if layer_idx % 2 == 1:
                # Self-attention layer: no encoder context.
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
                continue

            # Cross-attention layer: alternate between text and visual context.
            attend_text = layer_idx % (2 * self.attend_text_every_n_blocks) == 0
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=text_mask if attend_text else vis_mask,
                temb=temb,
            )

        # Same output head as the base DiT.
        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
|
|
|
|
class ActionEncoder(nn.Module):
    """Embeds noisy actions (plus DOF mask channels) together with the timestep.

    Pipeline: Linear(action) → concat with sinusoidal time embedding →
    SiLU(Linear) → Linear.
    """

    def __init__(self, action_input_dim: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer1 = nn.Linear(action_input_dim, hidden_size)
        self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Args:
            actions: (B, T, action_input_dim) noisy actions (+ DOF mask)
            timesteps: (B,) discretized timesteps
        """
        horizon = actions.shape[1]
        # Broadcast each sample's timestep across its action horizon.
        per_step_t = timesteps.unsqueeze(1).expand(-1, horizon)

        action_emb = self.layer1(actions)
        time_emb = self.pos_encoding(per_step_t).to(dtype=action_emb.dtype)

        fused = torch.cat((action_emb, time_emb), dim=-1)
        return self.layer3(F.silu(self.layer2(fused)))
|
|
|
|
class ActionDecoder(nn.Module):
    """Small MLP head (Linear → ReLU → Linear) mapping DiT features to velocities."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)
|
|
|
|
class FlowMatchingDiTHead(nn.Module):
    """Flow matching action head using DiT (Diffusion Transformer).

    Replaces the fm_action_expert (Qwen3VLTextModel-based) with a DiT that uses
    explicit cross-attention to VLM hidden states instead of KV cache continuation.

    Training (see :meth:`forward`):
        1. Sample noise and a timestep t derived from a Beta distribution
        2. Compute noisy trajectory: x_t = (1-t)*noise + t*actions
        3. Compute velocity target: v = actions - noise
        4. Encode noisy actions + DOF mask + timestep → action features
        5. Run DiT with cross-attention to VLM hidden states
        6. Decode to action-space velocity prediction

    Inference (see :meth:`predict_action`):
        Euler integration from pure noise (t=0) to clean actions (t=1)
        over num_inference_timesteps steps.
    """

    def __init__(
        self,
        action_dim: int,
        action_chunk_size: int,
        cross_attention_dim: int,
        num_inference_timesteps: int = 4,
        config: Optional[dict] = None,
    ):
        """
        Args:
            action_dim: per-step action dimensionality.
            action_chunk_size: number of future action steps predicted at once.
            cross_attention_dim: width of the VLM hidden states attended to.
            num_inference_timesteps: Euler steps used by predict_action.
            config: optional dict overriding any of the defaults below.
        """
        super().__init__()
        # Default hyperparameters; entries in `config` override these.
        cfg = {
            "num_layers": 16,
            "num_attention_heads": 12,
            "attention_head_dim": 64,
            "output_dim": 1024,
            "dropout": 0.2,
            "interleave_self_attention": True,
            "norm_type": "ada_norm",
            "final_dropout": True,
            "add_pos_embed": True,
            "noise_beta_alpha": 1.5,
            "noise_beta_beta": 1.0,
            "noise_s": 0.999,
            "num_timestep_buckets": 1000,
            "attn_implementation": "sdpa",
            "use_alternate_vl_dit": False,
            "attend_text_every_n_blocks": 2,
        }
        if config is not None:
            cfg.update(config)

        self.action_dim = action_dim
        self.action_chunk_size = action_chunk_size
        self.num_inference_timesteps = num_inference_timesteps
        self.num_timestep_buckets = cfg["num_timestep_buckets"]
        self.noise_s = cfg["noise_s"]
        self.use_alternate_vl_dit = cfg["use_alternate_vl_dit"]
        self.add_pos_embed = cfg["add_pos_embed"]

        num_attention_heads = cfg["num_attention_heads"]
        attention_head_dim = cfg["attention_head_dim"]
        output_dim = cfg["output_dim"]
        inner_dim = num_attention_heads * attention_head_dim

        dit_kwargs = dict(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            output_dim=output_dim,
            num_layers=cfg["num_layers"],
            dropout=cfg["dropout"],
            norm_type=cfg["norm_type"],
            final_dropout=cfg["final_dropout"],
            interleave_self_attention=cfg["interleave_self_attention"],
            cross_attention_dim=cross_attention_dim,
            attn_implementation=cfg["attn_implementation"],
        )
        if self.use_alternate_vl_dit:
            self.dit = AlternateVLDiT(
                **dit_kwargs,
                attend_text_every_n_blocks=cfg["attend_text_every_n_blocks"],
            )
        else:
            self.dit = DiT(**dit_kwargs)

        # Encoder input is [noisy_actions ; dof_mask], hence action_dim * 2.
        self.action_encoder = ActionEncoder(action_dim * 2, inner_dim)
        self.action_decoder = ActionDecoder(output_dim, inner_dim, action_dim)

        if self.add_pos_embed:
            # Learned position table over the action horizon (at least 256 slots).
            max_seq_len = max(action_chunk_size, 256)
            self.position_embedding = nn.Embedding(max_seq_len, inner_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        # Beta(alpha, beta) parameters for the timestep sampling schedule.
        self._beta_alpha = cfg["noise_beta_alpha"]
        self._beta_beta = cfg["noise_beta_beta"]

    def reset_parameters(self):
        """Re-apply proper initialization.

        HuggingFace from_pretrained calls _init_weights on modules whose
        parameters are absent from the checkpoint, overwriting any custom
        init done in __init__. Call this after from_pretrained when loading
        from a base VLM checkpoint that does not contain DiT weights.
        """
        if self.add_pos_embed:
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Mirrors nn.Linear's own default reset_parameters scheme.
                nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
                if module.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    nn.init.uniform_(module.bias, -bound, bound)
            elif isinstance(module, nn.LayerNorm):
                if module.elementwise_affine:
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)

    def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
        """Sample (batch_size,) flow-matching timesteps in [0, 1].

        Draws from Beta(alpha, beta), clamps at noise_s, then maps through
        (s - sample) / s. With the default alpha > beta this biases samples
        toward t = 0 (the high-noise end).
        """
        beta_dist = Beta(self._beta_alpha, self._beta_beta)
        sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
        return (self.noise_s - sample) / self.noise_s

    def _encode_actions(
        self,
        noisy_actions: torch.Tensor,
        t_discretized: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor],
        device,
    ) -> torch.Tensor:
        """Encode noisy actions with DOF mask and timestep, add position embeddings."""
        if action_dof_mask is not None:
            encoder_input = torch.cat(
                [noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1
            )
        else:
            # No mask supplied: treat every DOF as valid (all-ones channels).
            encoder_input = torch.cat(
                [noisy_actions, torch.ones_like(noisy_actions)], dim=-1
            )

        action_features = self.action_encoder(encoder_input, t_discretized)

        if self.add_pos_embed:
            # Add learned per-position embeddings over the action horizon.
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        return action_features

    def _dit_forward(
        self,
        sa_embs: torch.Tensor,
        vl_embs: torch.Tensor,
        t_discretized: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor],
        image_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Dispatch to the DiT; only AlternateVLDiT accepts an image_mask."""
        if self.use_alternate_vl_dit:
            return self.dit(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                timestep=t_discretized,
                encoder_attention_mask=encoder_attention_mask,
                image_mask=image_mask,
            )
        return self.dit(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            timestep=t_discretized,
            encoder_attention_mask=encoder_attention_mask,
        )

    def forward(
        self,
        vl_embs: torch.Tensor,
        actions: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        image_mask: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Training forward pass.

        Args:
            vl_embs: (B, S, D) VLM hidden states for cross-attention
            actions: (B, T, action_dim) ground truth action trajectories
            action_dof_mask: (B, T, action_dim) DOF validity mask
            encoder_attention_mask: (B, S) bool – True = valid VLM token
            image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)

        Returns:
            (pred_v, velocity): predicted velocity and target velocity, both (B, T, action_dim)
        """
        device = vl_embs.device
        B = actions.shape[0]

        # Flow-matching interpolation: x_t = (1-t)*noise + t*actions,
        # with velocity target v = actions - noise.
        noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
        t = self.sample_time(B, device=device, dtype=actions.dtype)
        t_expanded = t[:, None, None]

        noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
        velocity = actions - noise

        # Quantize continuous t into integer buckets for the timestep encoder.
        t_discretized = (t * self.num_timestep_buckets).long()

        action_features = self._encode_actions(noisy_trajectory, t_discretized, action_dof_mask, device)

        model_output = self._dit_forward(
            action_features, vl_embs, t_discretized, encoder_attention_mask, image_mask
        )

        pred = self.action_decoder(model_output)
        # Defensive slice: keep only the action-horizon positions.
        pred_v = pred[:, :actions.shape[1]]

        return pred_v, velocity

    @torch.no_grad()
    def predict_action(
        self,
        vl_embs: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        image_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Inference: denoise actions from noise using Euler integration.

        Args:
            vl_embs: (B, S, D) VLM hidden states
            action_dof_mask: optional (B, T, action_dim) or (1, T, action_dim) DOF mask
            encoder_attention_mask: (B, S) bool – True = valid VLM token
            image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)

        Returns:
            (B, T, action_dim) denoised action trajectories
        """
        B = vl_embs.shape[0]
        device = vl_embs.device
        dtype = vl_embs.dtype

        # Start from pure noise at t = 0.
        actions = torch.randn(
            (B, self.action_chunk_size, self.action_dim),
            device=device, dtype=dtype,
        )

        dt = 1.0 / self.num_inference_timesteps

        for step in range(self.num_inference_timesteps):
            # Current integration time in [0, 1), then bucketized.
            t_cont = step / float(self.num_inference_timesteps)
            t_discretized_val = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full((B,), t_discretized_val, device=device, dtype=torch.long)

            action_features = self._encode_actions(actions, timesteps_tensor, action_dof_mask, device)

            model_output = self._dit_forward(
                action_features, vl_embs, timesteps_tensor, encoder_attention_mask, image_mask
            )

            pred = self.action_decoder(model_output)
            pred_velocity = pred[:, :self.action_chunk_size]

            # Euler step along the predicted velocity field.
            actions = actions + dt * pred_velocity

        return actions
|
|
|
|
| |
| |
| |
class AdaRMSNorm(nn.Module):
    """RMS norm with (scale, shift, gate) modulation from a conditioning vector.

    The modulation projection is zero-initialized, so at the start the layer
    acts as a plain RMS norm and the gate is all zeros.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.modulation = nn.Linear(dim, dim * 3)
        nn.init.zeros_(self.modulation.weight)
        nn.init.zeros_(self.modulation.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (modulated RMS-normed x, gate), both broadcastable over (B, T, D)."""
        # Compute RMS statistics in float32 for stability, cast back after.
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = (x * torch.rsqrt(mean_sq + self.eps)).to(x.dtype)

        scale, shift, gate = self.modulation(cond).chunk(3, dim=-1)
        modulated = normed * (1 + scale[:, None]) + shift[:, None]
        return modulated, gate[:, None]
|
|
|
|
class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward: down_proj(dropout(SiLU(gate_proj(x)) * up_proj(x)))."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU gating, then dropout, then projection back to `dim`.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
|
|
|
|
class MoTAttention(nn.Module):
    """Action queries attend over [cached VLM K/V ; fresh action K/V].

    Grouped-query attention: KV heads are repeat-interleaved up to the query
    head count so the whole sequence can go through SDPA.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        if num_attention_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim

        q_dim = num_attention_heads * head_dim
        kv_dim = num_kv_heads * head_dim
        self.q_proj = nn.Linear(hidden_size, q_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
        self.o_proj = nn.Linear(q_dim, hidden_size, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        action_hidden: torch.Tensor,
        vlm_cached_k: torch.Tensor,
        vlm_cached_v: torch.Tensor,
        vlm_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            action_hidden: (B, T_a, hidden_size) action-token features.
            vlm_cached_k / vlm_cached_v: (B, num_kv_heads, S, head_dim) frozen
                VLM key/value cache for one layer.
            vlm_attention_mask: optional (B, S) mask over the cached tokens;
                action tokens are always attendable.
        """
        batch, act_len, _ = action_hidden.shape

        def shape_heads(t: torch.Tensor, n_heads: int) -> torch.Tensor:
            return t.view(batch, act_len, n_heads, self.head_dim).transpose(1, 2)

        q = shape_heads(self.q_proj(action_hidden), self.num_attention_heads)
        act_k = shape_heads(self.k_proj(action_hidden), self.num_kv_heads)
        act_v = shape_heads(self.v_proj(action_hidden), self.num_kv_heads)

        # Prepend the frozen VLM cache along the sequence axis.
        full_k = torch.cat((vlm_cached_k, act_k), dim=2)
        full_v = torch.cat((vlm_cached_v, act_v), dim=2)

        # GQA: duplicate each kv head to cover its group of query heads.
        group_size = self.num_attention_heads // self.num_kv_heads
        full_k = full_k.repeat_interleave(group_size, dim=1)
        full_v = full_v.repeat_interleave(group_size, dim=1)

        sdpa_mask = None
        if vlm_attention_mask is not None:
            # Extend the VLM mask with all-ones for the action tokens.
            ones = vlm_attention_mask.new_ones(batch, act_len)
            sdpa_mask = torch.cat((vlm_attention_mask, ones), dim=1)[:, None, None, :]

        attn_out = F.scaled_dot_product_attention(
            q, full_k, full_v, attn_mask=sdpa_mask, dropout_p=0.0,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, act_len, -1)
        return self.dropout(self.o_proj(attn_out))
|
|
|
|
class MoTBlock(nn.Module):
    """One action-expert layer: gated attention sub-block then gated SwiGLU sub-block.

    Each sub-block is preceded by AdaRMSNorm, whose gate output scales the
    residual contribution.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        intermediate_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.pre_attn_norm = AdaRMSNorm(hidden_size)
        self.attn = MoTAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
        )
        self.pre_ffn_norm = AdaRMSNorm(hidden_size)
        self.ffn = SwiGLUFeedForward(hidden_size, intermediate_size, dropout=dropout)

    def forward(
        self,
        action_hidden: torch.Tensor,
        vlm_cached_k: torch.Tensor,
        vlm_cached_v: torch.Tensor,
        adarms_cond: torch.Tensor,
        vlm_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-block with a learned gate on the residual branch.
        attn_in, attn_gate = self.pre_attn_norm(action_hidden, adarms_cond)
        attn_out = self.attn(attn_in, vlm_cached_k, vlm_cached_v, vlm_attention_mask)
        action_hidden = action_hidden + attn_out * attn_gate

        # Feed-forward sub-block, same gating scheme.
        ffn_in, ffn_gate = self.pre_ffn_norm(action_hidden, adarms_cond)
        return action_hidden + self.ffn(ffn_in) * ffn_gate
|
|
|
|
class MoTDiT(nn.Module):
    """Stack of MoTBlocks; block *i* consumes the VLM KV cache of layer *i*."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: int,
        head_dim: int,
        intermediate_size: int,
        num_layers: int,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.blocks = nn.ModuleList(
            MoTBlock(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                intermediate_size=intermediate_size,
                dropout=dropout,
            )
            for _ in range(num_layers)
        )
        self.final_norm = AdaRMSNorm(hidden_size)

    def forward(
        self,
        action_hidden: torch.Tensor,
        vlm_kv_cache: list,
        adarms_cond: torch.Tensor,
        vlm_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for layer_idx, block in enumerate(self.blocks):
            # Each block reads the matching VLM layer's cached (K, V) pair.
            cached_k, cached_v = vlm_kv_cache[layer_idx]
            action_hidden = block(
                action_hidden, cached_k, cached_v, adarms_cond, vlm_attention_mask,
            )
        # Final norm; the gate output is discarded here.
        normed, _ = self.final_norm(action_hidden, adarms_cond)
        return normed
|
|
|
|
def _kv_pairs_from_past_key_values(past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """Extract per-layer (K, V) tensors from a HuggingFace decoder KV cache.

    The returned list is in the cache's own layer order, matching the
    transformer layers.
    """
    pairs: list[tuple[torch.Tensor, torch.Tensor]] = []
    # Index explicitly: Cache objects support len() and layer indexing.
    for layer_idx in range(len(past_key_values)):
        layer_kv = past_key_values[layer_idx]
        pairs.append((layer_kv[0], layer_kv[1]))
    return pairs
|
|
|
|
class MoTFlowMatchingHead(nn.Module):
    """Flow matching head: MoT-style action expert over VLM KV cache (concat + GQA).

    Training (``forward``) samples a timestep t, interpolates ground-truth
    actions with Gaussian noise as ``(1 - t) * noise + t * x_0`` and regresses
    the constant velocity field ``v = x_0 - noise``. Inference
    (``predict_action``) integrates the predicted velocity from pure noise
    (t = 0) to actions (t = 1) with a fixed-step Euler solver.
    """

    def __init__(
        self,
        action_dim: int,
        action_chunk_size: int,
        vlm_config,
        num_inference_timesteps: int = 10,
        config: Optional[dict] = None,
    ):
        """Build the action expert.

        Args:
            action_dim: Per-step action dimensionality.
            action_chunk_size: Number of future action steps predicted at once.
            vlm_config: Backbone VLM text config; supplies the KV geometry
                (``num_key_value_heads``, ``head_dim``), layer count, and the
                default FFN width.
            num_inference_timesteps: Euler steps used by ``predict_action``.
            config: Optional dict of overrides for the default hyperparameters
                (e.g. ``{"expert_num_layers": 24, "dropout": 0.1}``).
        """
        super().__init__()

        # The expert attends to the VLM's cached K/V tensors directly, so its
        # KV-head count and head_dim must match the VLM's; only the number of
        # query heads is free to differ.
        _vlm_num_q_heads = 8
        _vlm_num_kv_heads = vlm_config.num_key_value_heads
        _vlm_head_dim = getattr(
            vlm_config, "head_dim", vlm_config.hidden_size // vlm_config.num_attention_heads
        )

        cfg = {
            "hidden_size": 1024,
            "intermediate_size": vlm_config.intermediate_size // 4,
            "expert_num_layers": vlm_config.num_hidden_layers,
            "num_attention_heads": _vlm_num_q_heads,
            "num_kv_heads": _vlm_num_kv_heads,
            "head_dim": _vlm_head_dim,
            "dropout": 0.2,
            "add_pos_embed": True,
            "noise_beta_alpha": 1.5,
            "noise_beta_beta": 1.0,
            "noise_s": 0.999,
            "num_timestep_buckets": 1000,
        }
        if config is not None:
            # BUGFIX: the original did `config = cfg.copy()`, which rebinds
            # the local and silently DISCARDS every user-supplied override
            # (making e.g. the `expert_num_layers` hint in the error message
            # below impossible to act on). Merge the overrides instead.
            cfg.update(config)

        num_attention_heads = cfg["num_attention_heads"]
        num_kv_heads = cfg["num_kv_heads"]
        head_dim = cfg["head_dim"]
        hidden_size = cfg["hidden_size"]
        intermediate_size = cfg["intermediate_size"]
        num_layers = cfg["expert_num_layers"]

        self.action_dim = action_dim
        self.action_chunk_size = action_chunk_size
        self.num_inference_timesteps = num_inference_timesteps
        self.num_timestep_buckets = cfg["num_timestep_buckets"]
        self.noise_s = cfg["noise_s"]
        self.add_pos_embed = cfg["add_pos_embed"]

        # Input projection sees [noisy_action, dof_mask] concatenated → 2x dim.
        self.action_in_proj = nn.Linear(action_dim * 2, hidden_size)
        self.action_out_proj = nn.Linear(hidden_size, action_dim)

        # Timestep conditioning: sinusoidal features → two SiLU-activated linears.
        self.time_sinusoidal = SinusoidalPositionalEncoding(hidden_size)
        self.time_mlp_1 = nn.Linear(hidden_size, hidden_size)
        self.time_mlp_2 = nn.Linear(hidden_size, hidden_size)

        if self.add_pos_embed:
            # Learned positions over the action chunk; the 256 floor leaves
            # headroom for longer chunks without resizing the table.
            max_seq = max(action_chunk_size, 256)
            self.position_embedding = nn.Embedding(max_seq, hidden_size)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        self.dit = MoTDiT(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            num_layers=num_layers,
            dropout=cfg["dropout"],
        )

        # Beta-distribution parameters for the training timestep sampler.
        self._beta_alpha = cfg["noise_beta_alpha"]
        self._beta_beta = cfg["noise_beta_beta"]

    @property
    def num_dit_layers(self) -> int:
        """Number of expert blocks; must match ``len(past_key_values.key_cache)``."""
        return self.dit.num_layers

    def _vlm_kv_list_from_past(self, past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Validate the layer count, then unpack the cache into per-layer (K, V) pairs."""
        n = len(past_key_values)
        if n != self.num_dit_layers:
            raise ValueError(
                f"MoT expert has {self.num_dit_layers} blocks but `past_key_values` has {n} "
                "layers. Set `dit_action_head_config['expert_num_layers']` to match "
                "`text_config.num_hidden_layers`."
            )
        return _kv_pairs_from_past_key_values(past_key_values)

    def reset_parameters(self):
        """Re-apply proper initialization after from_pretrained."""
        if self.add_pos_embed:
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        # Pass 1: standard kaiming init for every Linear in the head.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
                if module.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    nn.init.uniform_(module.bias, -bound, bound)
        # Pass 2: zero AdaRMSNorm modulation AFTER the Linear pass. BUGFIX:
        # the original single-pass `if AdaRMSNorm ... elif Linear` loop zeroed
        # the modulation first and then re-initialized it, because
        # `self.modules()` yields the parent AdaRMSNorm *before* its child
        # `modulation` nn.Linear, which the elif branch then kaiming-inits.
        for module in self.modules():
            if isinstance(module, AdaRMSNorm):
                nn.init.zeros_(module.modulation.weight)
                nn.init.zeros_(module.modulation.bias)

    def _compute_adarms_cond(self, t_discretized: torch.Tensor) -> torch.Tensor:
        """Map discretized timestep buckets to the AdaRMS conditioning vector."""
        t_emb = self.time_sinusoidal(t_discretized.float())
        t_emb = t_emb.to(dtype=self.time_mlp_1.weight.dtype)
        t_emb = F.silu(self.time_mlp_1(t_emb))
        t_emb = F.silu(self.time_mlp_2(t_emb))
        return t_emb

    def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
        """Sample training timesteps in [0, 1].

        Draws from Beta(alpha, beta) clamped to ``noise_s`` and maps through
        ``(s - x) / s``; with the default Beta(1.5, 1.0) this biases t toward
        0, i.e. the high-noise end of the interpolation.
        """
        beta_dist = Beta(self._beta_alpha, self._beta_beta)
        sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
        return (self.noise_s - sample) / self.noise_s

    def _prepare_action_embeds(
        self,
        noisy_actions: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Project [noisy_actions ‖ dof_mask] to tokens; add learned positions.

        A missing dof mask is treated as all-ones (every DOF active).
        """
        if action_dof_mask is not None:
            x = torch.cat(
                [noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1,
            )
        else:
            x = torch.cat([noisy_actions, torch.ones_like(noisy_actions)], dim=-1)

        tokens = self.action_in_proj(x)

        if self.add_pos_embed:
            pos_ids = torch.arange(tokens.shape[1], dtype=torch.long, device=noisy_actions.device)
            tokens = tokens + self.position_embedding(pos_ids).unsqueeze(0)

        return tokens

    def forward(
        self,
        past_key_values: Cache,
        actions: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Training: returns (pred_velocity, target_velocity).

        Args:
            past_key_values: VLM decoder KV cache; layer count must equal ``num_dit_layers``.
            actions: (B, T, action_dim) ground-truth action chunk (x_0).
            action_dof_mask: optional (B, T, action_dim) active-DOF mask.
            encoder_attention_mask: optional mask over the cached VLM tokens.
        """
        vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
        device = actions.device
        B = actions.shape[0]

        noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
        t = self.sample_time(B, device=device, dtype=actions.dtype)
        t_expanded = t[:, None, None]

        # Linear interpolation: t=0 is pure noise, t=1 is clean actions.
        noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
        velocity = actions - noise

        t_discretized = (t * self.num_timestep_buckets).long()
        adarms_cond = self._compute_adarms_cond(t_discretized)

        action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)

        output = self.dit(
            action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
        )

        pred = self.action_out_proj(output)
        # Keep only the first T positions in case tokens run longer than actions.
        pred_v = pred[:, :actions.shape[1]]
        return pred_v, velocity

    def compute_velocity(
        self,
        past_key_values: Cache,
        actions: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute velocity prediction for pre-sampled noise and timestep.

        Used by DiffusionNFT where noise and timestep must be shared between
        the current policy (v_θ) and the reference policy (v_old).

        Args:
            past_key_values: VLM decoder KV cache
            actions: (B, T, action_dim) ground truth actions (x_0)
            noise: (B, T, action_dim) pre-sampled noise (ε)
            t: (B,) continuous timesteps in [0, 1)
            action_dof_mask, encoder_attention_mask: as in ``forward``.

        Returns:
            pred_v: (B, T, action_dim) predicted velocity
        """
        vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
        t_expanded = t[:, None, None]

        noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
        t_discretized = (t * self.num_timestep_buckets).long()
        adarms_cond = self._compute_adarms_cond(t_discretized)
        action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
        output = self.dit(
            action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
        )
        pred = self.action_out_proj(output)
        return pred[:, :actions.shape[1]]

    @torch.no_grad()
    def predict_action(
        self,
        past_key_values: Cache,
        action_dof_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Inference: Euler integration, returns (B, chunk_size, action_dim)."""
        # Batch size / device / dtype are inferred from the first cached key.
        k0 = past_key_values[0][0]
        B = k0.shape[0]
        device = k0.device
        dtype = k0.dtype
        vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)

        # Start from pure noise (t = 0) and integrate toward t = 1.
        actions = torch.randn(
            (B, self.action_chunk_size, self.action_dim),
            device=device, dtype=dtype,
        )
        dt = 1.0 / self.num_inference_timesteps

        for step in range(self.num_inference_timesteps):
            t_cont = step / float(self.num_inference_timesteps)
            t_disc_val = int(t_cont * self.num_timestep_buckets)
            t_tensor = torch.full((B,), t_disc_val, device=device, dtype=torch.long)

            adarms_cond = self._compute_adarms_cond(t_tensor)
            action_tokens = self._prepare_action_embeds(actions, action_dof_mask)

            output = self.dit(
                action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
            )
            pred_velocity = self.action_out_proj(output)[:, :self.action_chunk_size]
            actions = actions + dt * pred_velocity

        return actions