"""
DiT (Diffusion Transformer) based flow matching action head for PRTS.
Replaces the Qwen3VLTextModel-based fm_action_expert with a lightweight DiT
that uses explicit cross-attention to VLM hidden states, following the architecture
from GR00T / Pi0.5.
Architecture:
ActionEncoder(noisy_actions + dof_mask, timestep)
→ action_features
→ DiT(cross-attn to VLM hidden states, ada-norm timestep conditioning)
→ ActionDecoder → predicted velocity
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta
from typing import Optional
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
# DIT_PRESETS = {
# "DiT-B": {"num_attention_heads": 12, "attention_head_dim": 64, "output_dim": 768},
# "DiT-L": {"num_attention_heads": 32, "attention_head_dim": 48, "output_dim": 1536},
# }
class SinusoidalPositionalEncoding(nn.Module):
"""Sinusoidal positional encoding for sequence positions or timesteps."""
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
squeeze = False
if timesteps.dim() == 1:
timesteps = timesteps.unsqueeze(1)
squeeze = True
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
math.log(10000.0) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
if squeeze:
enc = enc.squeeze(1)
return enc
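# A minimal shape-check sketch (illustrative only; not used by the module):
# the encoding accepts both (B,) timesteps and (B, T) position grids.
def _demo_sinusoidal_encoding():
    enc = SinusoidalPositionalEncoding(embedding_dim=8)
    t = torch.tensor([0.0, 250.0, 999.0])       # (B,) scalar timesteps
    assert enc(t).shape == (3, 8)               # squeezed back to (B, dim)
    grid = torch.arange(6.0).reshape(2, 3)      # (B, T) position grid
    assert enc(grid).shape == (2, 3, 8)         # broadcast over positions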
class TimestepEncoder(nn.Module):
"""Projects scalar timesteps to embedding space via sinusoidal encoding + MLP."""
def __init__(self, embedding_dim: int):
super().__init__()
self.sinusoidal = SinusoidalPositionalEncoding(256)
self.linear_1 = nn.Linear(256, embedding_dim)
self.act = nn.SiLU()
self.linear_2 = nn.Linear(embedding_dim, embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = self.sinusoidal(timesteps)
t_emb = self.linear_1(t_emb.to(dtype=self.linear_1.weight.dtype))
t_emb = self.act(t_emb)
t_emb = self.linear_2(t_emb)
return t_emb
class AdaLayerNorm(nn.Module):
"""Adaptive Layer Normalization conditioned on timestep embeddings.
Applies scale-shift modulation: out = norm(x) * (1 + scale) + shift,
where (scale, shift) are linearly projected from the timestep embedding.
"""
def __init__(self, embedding_dim: int, eps: float = 1e-5):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=False)
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=-1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
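# Hedged sketch (assumes x is (B, T, D) and temb is (B, D)): with the modulation
# projection zeroed out, scale = shift = 0 and AdaLayerNorm reduces to a plain
# affine-free LayerNorm. Illustrative only; not called anywhere in this module.
def _demo_ada_layer_norm():
    ada = AdaLayerNorm(embedding_dim=16)
    nn.init.zeros_(ada.linear.weight)
    nn.init.zeros_(ada.linear.bias)
    x, temb = torch.randn(2, 5, 16), torch.randn(2, 16)
    assert torch.allclose(ada(x, temb), ada.norm(x))  # zero modulation → identity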
class DiTAttention(nn.Module):
"""Multi-head attention supporting both self-attention and cross-attention.
Supports two backends selected via ``attn_implementation``:
* ``"sdpa"`` (default) – uses :func:`F.scaled_dot_product_attention`, which
dispatches automatically to FlashAttention / memory-efficient attention
depending on the installed PyTorch build. The encoder padding mask is
expanded to ``(B, 1, 1, S)`` and passed as ``attn_mask``.
* ``"flash_attention_2"`` – calls the ``flash_attn`` package directly for
lower memory usage and higher throughput. For cross-attention with an
encoder padding mask the k/v tensors are unpadded and
:func:`flash_attn_varlen_func` is used so that padding tokens are never
processed. For self-attention (no mask) the simpler
:func:`flash_attn_func` is used.
"""
def __init__(
self,
query_dim: int,
num_heads: int,
head_dim: int,
cross_attention_dim: Optional[int] = None,
dropout: float = 0.0,
bias: bool = True,
attn_implementation: str = "sdpa",
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.attn_implementation = attn_implementation
inner_dim = num_heads * head_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
kv_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_k = nn.Linear(kv_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(kv_dim, inner_dim, bias=bias)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim, bias=bias),
nn.Dropout(dropout),
)
# ------------------------------------------------------------------
# Flash-Attention backend
# ------------------------------------------------------------------
def _flash_attn_forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
"""Run Flash Attention via HuggingFace's ``_flash_attention_forward``.
Args:
q: ``(B, T_q, H, D)``
k: ``(B, T_k, H, D)``
v: ``(B, T_k, H, D)``
attention_mask: ``(B, T_k)`` bool, True = valid token.
Returns:
``(B, T_q, H*D)``
"""
B, T_q, H, D = q.shape
# _flash_attention_forward returns (B, T_q, H, D); handles unpad/varlen internally.
out = _flash_attention_forward(
q, k, v,
attention_mask=attention_mask,
query_length=T_q,
is_causal=False,
dropout=0.0,
)
return out.reshape(B, T_q, H * D)
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, T, _ = hidden_states.shape
q = self.to_q(hidden_states)
kv_input = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
k = self.to_k(kv_input)
v = self.to_v(kv_input)
if self.attn_implementation == "flash_attention_2":
# Flash Attention expects (B, S, H, D)
q = q.view(B, T, self.num_heads, self.head_dim)
k = k.view(B, -1, self.num_heads, self.head_dim)
v = v.view(B, -1, self.num_heads, self.head_dim)
attn_output = self._flash_attn_forward(q, k, v, attention_mask)
else:
# SDPA expects (B, H, S, D)
q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Expand (B, S) bool mask → (B, 1, 1, S) for broadcasting.
sdpa_mask = None
if attention_mask is not None:
if attention_mask.dim() == 2:
sdpa_mask = attention_mask[:, None, None, :]
else:
sdpa_mask = attention_mask
attn_output = F.scaled_dot_product_attention(
q, k, v, attn_mask=sdpa_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
return self.to_out(attn_output)
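# Hedged usage sketch of the two attention modes (toy sizes; illustrative only).
# Cross-attention consumes a (B, S) bool padding mask over the encoder tokens.
def _demo_dit_attention():
    cross = DiTAttention(query_dim=64, num_heads=4, head_dim=16,
                         cross_attention_dim=128)
    x = torch.randn(2, 7, 64)                   # action tokens (B, T, D_q)
    ctx = torch.randn(2, 11, 128)               # VLM tokens (B, S, D_kv)
    mask = torch.ones(2, 11, dtype=torch.bool)  # True = valid token
    mask[0, 8:] = False                         # pad out the tail of sample 0
    assert cross(x, encoder_hidden_states=ctx, attention_mask=mask).shape == (2, 7, 64)
    self_attn = DiTAttention(query_dim=64, num_heads=4, head_dim=16)
    assert self_attn(x).shape == (2, 7, 64)     # no context → self-attention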
class FeedForward(nn.Module):
"""Feed-forward network with GELU activation."""
def __init__(self, dim: int, dropout: float = 0.0, mult: int = 4):
super().__init__()
inner_dim = dim * mult
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(approximate="tanh"),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class BasicTransformerBlock(nn.Module):
"""Transformer block with self/cross-attention, optional AdaLayerNorm, and feed-forward.
When cross_attention_dim is set, the attention block performs cross-attention
to encoder_hidden_states. Otherwise, it performs self-attention.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
norm_type: str = "ada_norm",
final_dropout: bool = False,
attn_implementation: str = "sdpa",
):
super().__init__()
self.norm_type = norm_type
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim)
else:
self.norm1 = nn.LayerNorm(dim)
self.attn1 = DiTAttention(
query_dim=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
dropout=dropout,
attn_implementation=attn_implementation,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
self.final_dropout = nn.Dropout(dropout) if final_dropout else None
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, temb)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
if self.final_dropout is not None:
attn_output = self.final_dropout(attn_output)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class DiT(nn.Module):
"""Diffusion Transformer with cross-attention to VLM context features.
Interleaves cross-attention blocks (attending to encoder_hidden_states)
with self-attention blocks when interleave_self_attention=True.
Uses AdaLayerNorm for timestep conditioning throughout.
Output block applies timestep-conditioned scale-shift before final projection.
"""
def __init__(
self,
num_attention_heads: int = 12,
attention_head_dim: int = 64,
output_dim: int = 768,
num_layers: int = 12,
dropout: float = 0.1,
norm_type: str = "ada_norm",
final_dropout: bool = True,
interleave_self_attention: bool = False,
cross_attention_dim: Optional[int] = None,
attn_implementation: str = "sdpa",
):
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.output_dim = output_dim
self.num_layers = num_layers
self.interleave_self_attention = interleave_self_attention
self.timestep_encoder = TimestepEncoder(self.inner_dim)
all_blocks = []
for idx in range(num_layers):
use_self_attn = idx % 2 == 1 and interleave_self_attention
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
all_blocks.append(
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=curr_cross_attention_dim,
norm_type=norm_type,
final_dropout=final_dropout,
attn_implementation=attn_implementation,
)
)
self.transformer_blocks = nn.ModuleList(all_blocks)
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1 and self.interleave_self_attention:
hidden_states = block(
hidden_states,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
temb=temb,
)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(hidden_states)
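# End-to-end sketch of the DiT at toy sizes (illustrative only): 4 layers with
# interleaved self-attention, so blocks 0/2 cross-attend and blocks 1/3 self-attend.
def _demo_dit():
    dit = DiT(num_attention_heads=4, attention_head_dim=16, output_dim=32,
              num_layers=4, interleave_self_attention=True,
              cross_attention_dim=128)
    tokens = torch.randn(2, 10, 64)             # inner_dim = 4 heads * 16 = 64
    vl = torch.randn(2, 20, 128)                # VLM hidden states (B, S, D)
    t = torch.randint(0, 1000, (2,))            # bucketed timesteps (B,)
    assert dit(tokens, vl, t).shape == (2, 10, 32)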
class AlternateVLDiT(DiT):
"""DiT variant that separates visual and text tokens during cross-attention.
Mirrors GR00T's AlternateVLDiT: even-indexed blocks do cross-attention,
alternating every ``attend_text_every_n_blocks`` between text tokens and
visual tokens. Odd-indexed blocks do self-attention (requires
``interleave_self_attention=True``).
When no visual tokens are present (``image_mask`` is None or all-False),
all valid tokens are treated as text.
"""
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
super().__init__(*args, **kwargs)
assert self.interleave_self_attention, (
"AlternateVLDiT requires interleave_self_attention=True"
)
self.attend_text_every_n_blocks = attend_text_every_n_blocks
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
encoder_attention_mask: (B, S) bool – True = valid VLM token.
image_mask: (B, S) bool – True = visual token position.
If None, all valid tokens are treated as text.
"""
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
B, S, _ = encoder_hidden_states.shape
backbone_mask = (
encoder_attention_mask.bool()
if encoder_attention_mask is not None
else torch.ones(B, S, dtype=torch.bool, device=hidden_states.device)
)
if image_mask is not None and image_mask.any():
vis_mask = image_mask.bool() & backbone_mask # visual tokens
text_mask = (~image_mask.bool()) & backbone_mask # text tokens
else:
# No visual tokens – treat everything as text.
vis_mask = torch.zeros_like(backbone_mask)
text_mask = backbone_mask
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
# Self-attention block.
hidden_states = block(
hidden_states,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
# Cross-attention block: alternate text / visual every N blocks.
if idx % (2 * self.attend_text_every_n_blocks) == 0:
curr_mask = text_mask
else:
curr_mask = vis_mask
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=curr_mask,
temb=temb,
)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(hidden_states)
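# For reference, the block schedule with the default attend_text_every_n_blocks=2
# (even blocks cross-attend, odd blocks self-attend):
#
#   block 0 → cross (text)      block 1 → self
#   block 2 → cross (visual)    block 3 → self
#   block 4 → cross (text)      block 5 → self   ...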
class ActionEncoder(nn.Module):
"""Encodes noisy actions (optionally concatenated with DOF mask) and timestep
into hidden features via MLP + sinusoidal time encoding.
    Architecture: Linear → concat(action_emb, time_emb) → Linear → SiLU → Linear
"""
def __init__(self, action_input_dim: int, hidden_size: int):
super().__init__()
self.hidden_size = hidden_size
self.layer1 = nn.Linear(action_input_dim, hidden_size)
self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
"""
Args:
actions: (B, T, action_input_dim) noisy actions (+ DOF mask)
timesteps: (B,) discretized timesteps
"""
B, T, _ = actions.shape
timesteps_expanded = timesteps.unsqueeze(1).expand(-1, T)
a_emb = self.layer1(actions)
tau_emb = self.pos_encoding(timesteps_expanded).to(dtype=a_emb.dtype)
x = torch.cat([a_emb, tau_emb], dim=-1)
x = F.silu(self.layer2(x))
x = self.layer3(x)
return x
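# Hedged sketch of the encoder's input contract (illustrative only): in this file
# the noisy actions arrive already concatenated with the DOF mask, so
# action_input_dim is twice the raw action dimension.
def _demo_action_encoder():
    enc = ActionEncoder(action_input_dim=2 * 7, hidden_size=64)
    noisy_plus_mask = torch.randn(2, 16, 14)    # (B, T, action_dim * 2)
    t = torch.randint(0, 1000, (2,))            # (B,) bucketed timesteps
    assert enc(noisy_plus_mask, t).shape == (2, 16, 64)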
class ActionDecoder(nn.Module):
"""2-layer MLP that decodes DiT output to action-space velocity."""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.layer1 = nn.Linear(input_dim, hidden_dim)
self.layer2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer2(F.relu(self.layer1(x)))
class FlowMatchingDiTHead(nn.Module):
"""Flow matching action head using DiT (Diffusion Transformer).
Replaces the fm_action_expert (Qwen3VLTextModel-based) with a DiT that uses
explicit cross-attention to VLM hidden states instead of KV cache continuation.
Training:
1. Sample noise and timestep from Beta distribution
2. Compute noisy trajectory: x_t = (1-t)*noise + t*actions
3. Compute velocity target: v = actions - noise
4. Encode noisy actions + DOF mask + timestep → action features
5. Prepend learned future query tokens
6. Run DiT with cross-attention to VLM hidden states
7. Decode to action-space velocity prediction
Inference:
Euler integration from pure noise (t=0) to clean actions (t=1)
over num_inference_timesteps steps.
"""
def __init__(
self,
action_dim: int,
action_chunk_size: int,
cross_attention_dim: int,
num_inference_timesteps: int = 4,
config: Optional[dict] = None,
):
super().__init__()
cfg = {
"num_layers": 16,
"num_attention_heads": 12,
"attention_head_dim": 64,
"output_dim": 1024,
"dropout": 0.2,
"interleave_self_attention": True,
"norm_type": "ada_norm",
"final_dropout": True,
"add_pos_embed": True,
"noise_beta_alpha": 1.5,
"noise_beta_beta": 1.0,
"noise_s": 0.999,
"num_timestep_buckets": 1000,
"attn_implementation": "sdpa",
"use_alternate_vl_dit": False,
"attend_text_every_n_blocks": 2,
}
if config is not None:
cfg.update(config)
# dit_model_type = config.get("dit_model_type")
# if dit_model_type and dit_model_type in DIT_PRESETS:
# cfg.update(DIT_PRESETS[dit_model_type])
# cfg.pop("dit_model_type", None)
self.action_dim = action_dim
self.action_chunk_size = action_chunk_size
self.num_inference_timesteps = num_inference_timesteps
self.num_timestep_buckets = cfg["num_timestep_buckets"]
self.noise_s = cfg["noise_s"]
self.use_alternate_vl_dit = cfg["use_alternate_vl_dit"]
self.add_pos_embed = cfg["add_pos_embed"]
num_attention_heads = cfg["num_attention_heads"]
attention_head_dim = cfg["attention_head_dim"]
output_dim = cfg["output_dim"]
inner_dim = num_attention_heads * attention_head_dim
dit_kwargs = dict(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
output_dim=output_dim,
num_layers=cfg["num_layers"],
dropout=cfg["dropout"],
norm_type=cfg["norm_type"],
final_dropout=cfg["final_dropout"],
interleave_self_attention=cfg["interleave_self_attention"],
cross_attention_dim=cross_attention_dim,
attn_implementation=cfg["attn_implementation"],
)
if self.use_alternate_vl_dit:
self.dit = AlternateVLDiT(
**dit_kwargs,
attend_text_every_n_blocks=cfg["attend_text_every_n_blocks"],
)
else:
self.dit = DiT(**dit_kwargs)
# action_dim * 2: noisy action + DOF mask concatenated
self.action_encoder = ActionEncoder(action_dim * 2, inner_dim)
self.action_decoder = ActionDecoder(output_dim, inner_dim, action_dim)
if self.add_pos_embed:
max_seq_len = max(action_chunk_size, 256)
self.position_embedding = nn.Embedding(max_seq_len, inner_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
# self.beta_dist = Beta(cfg["noise_beta_alpha"], cfg["noise_beta_beta"])
self._beta_alpha = cfg["noise_beta_alpha"]
self._beta_beta = cfg["noise_beta_beta"]
def reset_parameters(self):
"""Re-apply proper initialization.
HuggingFace from_pretrained calls _init_weights on modules whose
parameters are absent from the checkpoint, overwriting any custom
init done in __init__. Call this after from_pretrained when loading
from a base VLM checkpoint that does not contain DiT weights.
"""
if self.add_pos_embed:
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, nn.LayerNorm):
if module.elementwise_affine:
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
beta_dist = Beta(self._beta_alpha, self._beta_beta)
sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
return (self.noise_s - sample) / self.noise_s
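    # For reference: u ~ Beta(alpha, beta) lies in [0, 1] and is clamped to at
    # most s = noise_s, so t = (s - u) / s also lies in [0, 1]. With the default
    # alpha=1.5, beta=1.0 the Beta mass leans toward u ≈ 1, i.e. t ≈ 0, biasing
    # training toward timesteps near pure noise.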
def _encode_actions(
self,
noisy_actions: torch.Tensor,
t_discretized: torch.Tensor,
action_dof_mask: Optional[torch.Tensor],
device,
) -> torch.Tensor:
"""Encode noisy actions with DOF mask and timestep, add position embeddings."""
if action_dof_mask is not None:
encoder_input = torch.cat(
[noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1
)
else:
encoder_input = torch.cat(
[noisy_actions, torch.ones_like(noisy_actions)], dim=-1
)
action_features = self.action_encoder(encoder_input, t_discretized)
if self.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
return action_features
def _dit_forward(
self,
sa_embs: torch.Tensor,
vl_embs: torch.Tensor,
t_discretized: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor],
image_mask: Optional[torch.Tensor],
) -> torch.Tensor:
if self.use_alternate_vl_dit:
return self.dit(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=t_discretized,
encoder_attention_mask=encoder_attention_mask,
image_mask=image_mask,
)
return self.dit(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=t_discretized,
encoder_attention_mask=encoder_attention_mask,
)
def forward(
self,
vl_embs: torch.Tensor,
actions: torch.Tensor,
action_dof_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_mask: Optional[torch.Tensor] = None,
) -> tuple:
"""Training forward pass.
Args:
vl_embs: (B, S, D) VLM hidden states for cross-attention
actions: (B, T, action_dim) ground truth action trajectories
action_dof_mask: (B, T, action_dim) DOF validity mask
encoder_attention_mask: (B, S) bool – True = valid VLM token
image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)
Returns:
(pred_v, velocity): predicted velocity and target velocity, both (B, T, action_dim)
"""
device = vl_embs.device
B = actions.shape[0]
noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
t = self.sample_time(B, device=device, dtype=actions.dtype)
t_expanded = t[:, None, None]
noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
velocity = actions - noise
t_discretized = (t * self.num_timestep_buckets).long()
action_features = self._encode_actions(noisy_trajectory, t_discretized, action_dof_mask, device)
model_output = self._dit_forward(
action_features, vl_embs, t_discretized, encoder_attention_mask, image_mask
)
pred = self.action_decoder(model_output)
pred_v = pred[:, :actions.shape[1]]
return pred_v, velocity
@torch.no_grad()
def predict_action(
self,
vl_embs: torch.Tensor,
action_dof_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Inference: denoise actions from noise using Euler integration.
Args:
vl_embs: (B, S, D) VLM hidden states
action_dof_mask: optional (B, T, action_dim) or (1, T, action_dim) DOF mask
encoder_attention_mask: (B, S) bool – True = valid VLM token
image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)
Returns:
(B, T, action_dim) denoised action trajectories
"""
B = vl_embs.shape[0]
device = vl_embs.device
dtype = vl_embs.dtype
actions = torch.randn(
(B, self.action_chunk_size, self.action_dim),
device=device, dtype=dtype,
)
dt = 1.0 / self.num_inference_timesteps
for step in range(self.num_inference_timesteps):
t_cont = step / float(self.num_inference_timesteps)
t_discretized_val = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full((B,), t_discretized_val, device=device, dtype=torch.long)
action_features = self._encode_actions(actions, timesteps_tensor, action_dof_mask, device)
model_output = self._dit_forward(
action_features, vl_embs, timesteps_tensor, encoder_attention_mask, image_mask
)
pred = self.action_decoder(model_output)
pred_velocity = pred[:, :self.action_chunk_size]
actions = actions + dt * pred_velocity
return actions
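# End-to-end sketch of the DiT head at toy sizes (illustrative only; the config
# keys shown are the ones defined in __init__ above).
def _demo_flow_matching_dit_head():
    head = FlowMatchingDiTHead(
        action_dim=7, action_chunk_size=16, cross_attention_dim=128,
        config={"num_layers": 2, "num_attention_heads": 4,
                "attention_head_dim": 16, "output_dim": 32},
    )
    vl = torch.randn(2, 20, 128)                # (B, S, D) VLM hidden states
    actions = torch.randn(2, 16, 7)             # (B, T, action_dim)
    pred_v, target_v = head(vl, actions)        # training pass
    assert pred_v.shape == target_v.shape == (2, 16, 7)
    assert head.predict_action(vl).shape == (2, 16, 7)  # Euler rollout from noise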
# ============================================================================
# Pi0.5-style KV-cache action expert (VLM K/V concat + GQA + SwiGLU FFN)
# ============================================================================
class AdaRMSNorm(nn.Module):
"""Adaptive RMS normalization: (scale, shift, gate) from cond; zero-init."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.modulation = nn.Linear(dim, dim * 3)
nn.init.zeros_(self.modulation.weight)
nn.init.zeros_(self.modulation.bias)
def forward(self, x: torch.Tensor, cond: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
var = x.float().pow(2).mean(-1, keepdim=True)
normed = (x * torch.rsqrt(var + self.eps)).to(x.dtype)
scale, shift, gate = self.modulation(cond).chunk(3, dim=-1)
normed = normed * (1 + scale[:, None]) + shift[:, None]
return normed, gate[:, None]
class SwiGLUFeedForward(nn.Module):
"""SiLU(gate_proj(x)) * up_proj(x) → down_proj."""
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0, bias: bool = True):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.dropout(F.silu(self.gate_proj(x)) * self.up_proj(x)))
class MoTAttention(nn.Module):
"""Action Q attends to concatenated [VLM KV cache ; action KV]; GQA expand for SDPA."""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_kv_heads: int,
head_dim: int,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
if num_attention_heads % num_kv_heads != 0:
raise ValueError(
f"num_attention_heads ({num_attention_heads}) must be divisible by "
f"num_kv_heads ({num_kv_heads})"
)
self.num_attention_heads = num_attention_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
q_dim = num_attention_heads * head_dim
kv_dim = num_kv_heads * head_dim
self.q_proj = nn.Linear(hidden_size, q_dim, bias=bias)
self.k_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
self.v_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
self.o_proj = nn.Linear(q_dim, hidden_size, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(
self,
action_hidden: torch.Tensor,
vlm_cached_k: torch.Tensor,
vlm_cached_v: torch.Tensor,
vlm_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, T_a, _ = action_hidden.shape
q = self.q_proj(action_hidden)
act_k = self.k_proj(action_hidden)
act_v = self.v_proj(action_hidden)
q = q.view(B, T_a, self.num_attention_heads, self.head_dim).transpose(1, 2)
act_k = act_k.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
act_v = act_v.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
k = torch.cat([vlm_cached_k, act_k], dim=2)
v = torch.cat([vlm_cached_v, act_v], dim=2)
repeat_factor = self.num_attention_heads // self.num_kv_heads
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)
sdpa_mask = None
if vlm_attention_mask is not None:
action_mask = vlm_attention_mask.new_ones(B, T_a)
combined_mask = torch.cat([vlm_attention_mask, action_mask], dim=1)
sdpa_mask = combined_mask[:, None, None, :]
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=sdpa_mask, dropout_p=0.0,
)
attn_out = attn_out.transpose(1, 2).contiguous().view(B, T_a, -1)
return self.dropout(self.o_proj(attn_out))
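# Hedged sketch of the KV-concat + GQA expansion (toy sizes; illustrative only).
# Cached VLM K/V arrive as (B, num_kv_heads, S, head_dim); action K/V are appended
# along the sequence axis before heads are repeat_interleaved up to the Q heads.
def _demo_mot_attention():
    attn = MoTAttention(hidden_size=64, num_attention_heads=8,
                        num_kv_heads=2, head_dim=16)
    x = torch.randn(2, 5, 64)                   # action tokens (B, T_a, hidden)
    k = torch.randn(2, 2, 12, 16)               # cached K (B, kv_heads, S, head_dim)
    v = torch.randn(2, 2, 12, 16)               # cached V, same shape
    mask = torch.ones(2, 12, dtype=torch.bool)  # (B, S), True = valid VLM token
    assert attn(x, k, v, vlm_attention_mask=mask).shape == (2, 5, 64)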
class MoTBlock(nn.Module):
"""AdaRMSNorm → attention → gated residual → AdaRMSNorm → SwiGLU FFN → gated residual."""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_kv_heads: int,
head_dim: int,
intermediate_size: int,
dropout: float = 0.0,
):
super().__init__()
self.pre_attn_norm = AdaRMSNorm(hidden_size)
self.attn = MoTAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dropout=dropout,
)
self.pre_ffn_norm = AdaRMSNorm(hidden_size)
self.ffn = SwiGLUFeedForward(hidden_size, intermediate_size, dropout=dropout)
def forward(
self,
action_hidden: torch.Tensor,
vlm_cached_k: torch.Tensor,
vlm_cached_v: torch.Tensor,
adarms_cond: torch.Tensor,
vlm_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed, gate1 = self.pre_attn_norm(action_hidden, adarms_cond)
attn_out = self.attn(normed, vlm_cached_k, vlm_cached_v, vlm_attention_mask)
action_hidden = action_hidden + attn_out * gate1
normed2, gate2 = self.pre_ffn_norm(action_hidden, adarms_cond)
action_hidden = action_hidden + self.ffn(normed2) * gate2
return action_hidden
class MoTDiT(nn.Module):
"""Stack of ActionBlocks; each block uses one VLM layer's KV pair."""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_kv_heads: int,
head_dim: int,
intermediate_size: int,
num_layers: int,
dropout: float = 0.2,
):
super().__init__()
self.num_layers = num_layers
self.blocks = nn.ModuleList([
MoTBlock(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
dropout=dropout,
)
for _ in range(num_layers)
])
self.final_norm = AdaRMSNorm(hidden_size)
def forward(
self,
action_hidden: torch.Tensor,
vlm_kv_cache: list,
adarms_cond: torch.Tensor,
vlm_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for idx, block in enumerate(self.blocks):
cached_k, cached_v = vlm_kv_cache[idx]
action_hidden = block(
action_hidden, cached_k, cached_v, adarms_cond, vlm_attention_mask,
)
action_hidden, _ = self.final_norm(action_hidden, adarms_cond)
return action_hidden
def _kv_pairs_from_past_key_values(past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
"""Per-layer (K, V) from a HuggingFace decoder KV cache (order matches transformer layers)."""
return [
(past_key_values[i][0], past_key_values[i][1])
for i in range(len(past_key_values))
]
class MoTFlowMatchingHead(nn.Module):
"""Flow matching head: MoT-style action expert over VLM KV cache (concat + GQA)."""
def __init__(
self,
action_dim: int,
action_chunk_size: int,
vlm_config,
num_inference_timesteps: int = 10,
config: Optional[dict] = None,
):
super().__init__()
        _vlm_num_q_heads = 8  # expert query heads; alternatively vlm_config.num_attention_heads // 2
_vlm_num_kv_heads = vlm_config.num_key_value_heads # 8
_vlm_head_dim = getattr(
vlm_config, "head_dim", vlm_config.hidden_size // vlm_config.num_attention_heads
) # 128
cfg = {
"hidden_size": 1024, # vlm_config.hidden_size // 2,
# "hidden_size": vlm_config.hidden_size // 2,
"intermediate_size": vlm_config.intermediate_size // 4,
"expert_num_layers": vlm_config.num_hidden_layers,
# Attention dims default to VLM values (required for KV cache compat)
"num_attention_heads": _vlm_num_q_heads,
"num_kv_heads": _vlm_num_kv_heads,
"head_dim": _vlm_head_dim,
# Noise schedule
"dropout": 0.2,
"add_pos_embed": True,
"noise_beta_alpha": 1.5,
"noise_beta_beta": 1.0,
"noise_s": 0.999,
"num_timestep_buckets": 1000,
}
        if config is not None:
            cfg.update(config)
num_attention_heads = cfg["num_attention_heads"]
num_kv_heads = cfg["num_kv_heads"]
head_dim = cfg["head_dim"]
hidden_size = cfg["hidden_size"]
intermediate_size = cfg["intermediate_size"]
num_layers = cfg["expert_num_layers"]
self.action_dim = action_dim
self.action_chunk_size = action_chunk_size
self.num_inference_timesteps = num_inference_timesteps
self.num_timestep_buckets = cfg["num_timestep_buckets"]
self.noise_s = cfg["noise_s"]
self.add_pos_embed = cfg["add_pos_embed"]
self.action_in_proj = nn.Linear(action_dim * 2, hidden_size)
self.action_out_proj = nn.Linear(hidden_size, action_dim)
self.time_sinusoidal = SinusoidalPositionalEncoding(hidden_size)
self.time_mlp_1 = nn.Linear(hidden_size, hidden_size)
self.time_mlp_2 = nn.Linear(hidden_size, hidden_size)
if self.add_pos_embed:
max_seq = max(action_chunk_size, 256)
self.position_embedding = nn.Embedding(max_seq, hidden_size)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.dit = MoTDiT(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
num_layers=num_layers,
dropout=cfg["dropout"],
)
self._beta_alpha = cfg["noise_beta_alpha"]
self._beta_beta = cfg["noise_beta_beta"]
@property
def num_dit_layers(self) -> int:
"""Number of expert blocks; must match ``len(past_key_values.key_cache)``."""
return self.dit.num_layers
def _vlm_kv_list_from_past(self, past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
n = len(past_key_values)
if n != self.num_dit_layers:
raise ValueError(
f"MoT expert has {self.num_dit_layers} blocks but `past_key_values` has {n} "
"layers. Set `dit_action_head_config['expert_num_layers']` to match "
"`text_config.num_hidden_layers`."
)
return _kv_pairs_from_past_key_values(past_key_values)
def reset_parameters(self):
"""Re-apply proper initialization after from_pretrained."""
if self.add_pos_embed:
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
for module in self.modules():
if isinstance(module, AdaRMSNorm):
nn.init.zeros_(module.modulation.weight)
nn.init.zeros_(module.modulation.bias)
elif isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(module.bias, -bound, bound)
def _compute_adarms_cond(self, t_discretized: torch.Tensor) -> torch.Tensor:
t_emb = self.time_sinusoidal(t_discretized.float())
t_emb = t_emb.to(dtype=self.time_mlp_1.weight.dtype)
t_emb = F.silu(self.time_mlp_1(t_emb))
t_emb = F.silu(self.time_mlp_2(t_emb))
return t_emb
def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
beta_dist = Beta(self._beta_alpha, self._beta_beta)
sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
return (self.noise_s - sample) / self.noise_s
def _prepare_action_embeds(
self,
noisy_actions: torch.Tensor,
action_dof_mask: Optional[torch.Tensor],
) -> torch.Tensor:
if action_dof_mask is not None:
x = torch.cat(
[noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1,
)
else:
x = torch.cat([noisy_actions, torch.ones_like(noisy_actions)], dim=-1)
tokens = self.action_in_proj(x)
if self.add_pos_embed:
pos_ids = torch.arange(tokens.shape[1], dtype=torch.long, device=noisy_actions.device)
tokens = tokens + self.position_embedding(pos_ids).unsqueeze(0)
return tokens
def forward(
self,
past_key_values: Cache,
actions: torch.Tensor,
action_dof_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> tuple:
"""Training: returns (pred_velocity, target_velocity).
Args:
past_key_values: VLM decoder KV cache; layer count must equal ``num_dit_layers``.
"""
vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
device = actions.device
B = actions.shape[0]
noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
t = self.sample_time(B, device=device, dtype=actions.dtype)
t_expanded = t[:, None, None]
noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
velocity = actions - noise
t_discretized = (t * self.num_timestep_buckets).long()
adarms_cond = self._compute_adarms_cond(t_discretized)
action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
output = self.dit(
action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
)
pred = self.action_out_proj(output)
pred_v = pred[:, :actions.shape[1]]
return pred_v, velocity
def compute_velocity(
self,
past_key_values: Cache,
actions: torch.Tensor,
noise: torch.Tensor,
t: torch.Tensor,
action_dof_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute velocity prediction for pre-sampled noise and timestep.
Used by DiffusionNFT where noise and timestep must be shared between
the current policy (v_θ) and the reference policy (v_old).
Args:
past_key_values: VLM decoder KV cache
actions: (B, T, action_dim) ground truth actions (x_0)
noise: (B, T, action_dim) pre-sampled noise (ε)
t: (B,) continuous timesteps in [0, 1)
            action_dof_mask, encoder_attention_mask: as in ``forward``.
Returns:
pred_v: (B, T, action_dim) predicted velocity
"""
vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
device = actions.device
t_expanded = t[:, None, None]
noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
t_discretized = (t * self.num_timestep_buckets).long()
adarms_cond = self._compute_adarms_cond(t_discretized)
action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
output = self.dit(
action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
)
pred = self.action_out_proj(output)
return pred[:, :actions.shape[1]]
@torch.no_grad()
def predict_action(
self,
past_key_values: Cache,
action_dof_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Inference: Euler integration, returns (B, chunk_size, action_dim)."""
k0 = past_key_values[0][0]
B = k0.shape[0]
device = k0.device
dtype = k0.dtype
vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
actions = torch.randn(
(B, self.action_chunk_size, self.action_dim),
device=device, dtype=dtype,
)
dt = 1.0 / self.num_inference_timesteps
for step in range(self.num_inference_timesteps):
t_cont = step / float(self.num_inference_timesteps)
t_disc_val = int(t_cont * self.num_timestep_buckets)
t_tensor = torch.full((B,), t_disc_val, device=device, dtype=torch.long)
adarms_cond = self._compute_adarms_cond(t_tensor)
action_tokens = self._prepare_action_embeds(actions, action_dof_mask)
output = self.dit(
action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
)
pred_velocity = self.action_out_proj(output)[:, :self.action_chunk_size]
actions = actions + dt * pred_velocity
return actions
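# End-to-end sketch of the MoT head (illustrative only). `SimpleNamespace` stands
# in for a real VLM config, and a plain list of (K, V) tuples stands in for the
# HF cache, which is all `_kv_pairs_from_past_key_values` actually indexes into.
def _demo_mot_flow_matching_head():
    from types import SimpleNamespace
    vlm_cfg = SimpleNamespace(hidden_size=256, num_attention_heads=8,
                              num_key_value_heads=2, head_dim=16,
                              intermediate_size=512, num_hidden_layers=2)
    head = MoTFlowMatchingHead(action_dim=7, action_chunk_size=16,
                               vlm_config=vlm_cfg)
    B, S = 2, 12
    kv = [(torch.randn(B, 2, S, 16), torch.randn(B, 2, S, 16))
          for _ in range(vlm_cfg.num_hidden_layers)]
    actions = torch.randn(B, 16, 7)
    pred_v, target_v = head(kv, actions)        # training pass
    assert pred_v.shape == target_v.shape == (B, 16, 7)
    assert head.predict_action(kv).shape == (B, 16, 7)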