| """Modeling code for MolmoAct2""" |
|
|
| import json |
| import math |
| import os |
| import re |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.nn.attention import SDPBackend, sdpa_kernel |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.generation import GenerationMixin |
| from transformers.masking_utils import create_causal_mask, create_masks_for_generate |
| from transformers.modeling_flash_attention_utils import ( |
| FlashAttentionKwargs, |
| _flash_attention_forward, |
| flash_attn_supports_top_left_mask, |
| ) |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| ) |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.models.auto import AutoModelForImageTextToText |
| from transformers.processing_utils import Unpack |
| from transformers.utils import ( |
| ModelOutput, |
| TransformersKwargs, |
| can_return_tuple, |
| logging, |
| ) |
|
|
| from .configuration_molmoact2 import ( |
| MolmoAct2ActionExpertConfig, |
| MolmoAct2AdapterConfig, |
| MolmoAct2Config, |
| MolmoAct2TextConfig, |
| MolmoAct2VitConfig, |
| ) |
| from .inference import ( |
| ActionCudaGraphManager, |
| DepthDecodeCudaGraphManager, |
| _ActionFlowInputs, |
| _cache_max_len_int, |
| _cache_seq_len_int, |
| _iter_cache_key_values, |
| ) |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| ACTION_START_TOKEN = "<action_start>" |
| ACTION_END_TOKEN = "<action_end>" |
| ACTION_OUTPUT_TOKEN = "<action_output>" |
| STATE_START_TOKEN = "<state_start>" |
| STATE_END_TOKEN = "<state_end>" |
| STATE_TOKEN_PREFIX = "<state_" |
| DEPTH_START_TOKEN = "<depth_start>" |
| DEPTH_END_TOKEN = "<depth_end>" |
| DEPTH_OUTPUT_TOKEN = "<depth_output>" |
| DEPTH_TOKEN_PREFIX = "<depth_" |
| SETUP_START_TOKEN = "<setup_start>" |
| SETUP_END_TOKEN = "<setup_end>" |
| CONTROL_START_TOKEN = "<control_start>" |
| CONTROL_END_TOKEN = "<control_end>" |
|
|
| _QUESTION_TRAILING_SENTENCE_PUNCTUATION = ".,!?;:,…" |
| _QUESTION_TRAILING_CLOSERS = "\"'”’)]}" |
| _QUESTION_SURROUNDING_DELIMITERS = "\"'`“”‘’[](){}" |
| _QUESTION_PREFIX_PATTERNS = tuple( |
| re.compile(pattern, flags=re.IGNORECASE) |
| for pattern in ( |
| r"^(?:task|instruction|language[_ ]instruction|goal)\s*[:\-]\s*", |
| r"^(?:the\s+task\s+is\s+to|your\s+task\s+is\s+to)\s+", |
| ) |
| ) |
|
|
| _DEPTH_REASONING_PATCH_SIZE = 32 |
| _DEPTH_REASONING_THRESHOLD = 0.996 |
|
|
|
|
| def _modulate( |
| x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor |
| ) -> torch.Tensor: |
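| # AdaLN-style modulation: per-example shift/scale vectors are broadcast over the sequence dimension. |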
| return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
|
|
|
|
| def _round_up_multiple(value: int, multiple_of: int) -> int: |
| if multiple_of <= 0: |
| return value |
| return int(math.ceil(value / multiple_of) * multiple_of) |
|
|
|
|
| def _init_linear(linear: nn.Linear, *, zero: bool = False, scale: float = 1.0) -> None: |
| if zero: |
| nn.init.zeros_(linear.weight) |
| else: |
| nn.init.xavier_uniform_(linear.weight) |
| if scale != 1.0: |
| with torch.no_grad(): |
| linear.weight.mul_(scale) |
| if linear.bias is not None: |
| nn.init.zeros_(linear.bias) |
|
|
|
|
| @dataclass |
| class ActionExpertContext: |
| kv_contexts: Sequence[Tuple[torch.Tensor, torch.Tensor]] |
| cross_mask: Optional[torch.Tensor] |
| self_mask: Optional[torch.Tensor] |
| valid_action: Optional[torch.Tensor] |
| rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None |
|
|
|
|
| @dataclass |
| class ActionExpertStepModulation: |
| conditioning: torch.Tensor |
| block_modulations: Sequence[Tuple[torch.Tensor, ...]] |
| final_modulation: Tuple[torch.Tensor, torch.Tensor] |
|
|
|
|
| class ActionExpertRMSNorm(nn.Module): |
| def __init__( |
| self, |
| size: int, |
| *, |
| eps: float = 1e-6, |
| elementwise_affine: bool = False, |
| device=None, |
| ) -> None: |
| super().__init__() |
| self.size = size |
| self.eps = eps |
| if elementwise_affine: |
| self.weight = nn.Parameter(torch.ones(size, device=device)) |
| else: |
| self.register_parameter("weight", None) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| with torch.autocast(enabled=False, device_type=x.device.type): |
| dtype = x.dtype |
| x_float = x.to(torch.float32) |
| variance = x_float.pow(2).mean(dim=-1, keepdim=True) |
| out = x_float * torch.rsqrt(variance + self.eps) |
| out = out.to(dtype) |
| if self.weight is not None: |
| out = out * self.weight |
| return out |
|
|
| def reset_parameters(self) -> None: |
| if self.weight is not None: |
| nn.init.ones_(self.weight) |
|
|
|
|
| class ActionExpertRotaryEmbedding(nn.Module): |
| def __init__(self, head_dim: int, base: float = 10000.0) -> None: |
| super().__init__() |
| if head_dim % 2 != 0: |
| raise ValueError("RoPE requires an even head_dim.") |
| self.head_dim = head_dim |
| self.base = base |
|
|
| def build_cache( |
| self, |
| *, |
| seq_len: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| half_dim = self.head_dim // 2 |
| inv_freq = 1.0 / ( |
| self.base |
| ** ( |
| torch.arange(0, half_dim, device=device, dtype=torch.float32) |
| / max(half_dim, 1) |
| ) |
| ) |
| positions = torch.arange(seq_len, device=device, dtype=torch.float32) |
| freqs = torch.outer(positions, inv_freq) |
| cos = freqs.cos().to(dtype=dtype).view(1, 1, seq_len, half_dim) |
| sin = freqs.sin().to(dtype=dtype).view(1, 1, seq_len, half_dim) |
| return cos, sin |
|
|
| def forward( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| *, |
| rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
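| # q and k arrive as (batch, heads, seq, head_dim); rotation uses the split-halves |
| # (non-interleaved) RoPE convention matching the cos/sin layout from build_cache. |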
| if rope_cache is None: |
| rope_cache = self.build_cache( |
| seq_len=q.shape[-2], device=q.device, dtype=q.dtype |
| ) |
| cos, sin = rope_cache |
| half_dim = self.head_dim // 2 |
|
|
| def _apply(x: torch.Tensor) -> torch.Tensor: |
| x1, x2 = x[..., :half_dim], x[..., half_dim:] |
| return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) |
|
|
| return _apply(q), _apply(k) |
|
|
|
|
| class ActionExpertSelfAttention(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| num_heads: int, |
| *, |
| attn_dropout: float = 0.0, |
| proj_dropout: float = 0.0, |
| qk_norm: bool = True, |
| qk_norm_eps: float = 1e-6, |
| use_rope: bool = True, |
| ) -> None: |
| super().__init__() |
| if hidden_size % num_heads != 0: |
| raise ValueError("hidden_size must be divisible by num_heads") |
| self.hidden_size = hidden_size |
| self.num_heads = num_heads |
| self.head_dim = hidden_size // num_heads |
| self.attn_dropout = attn_dropout |
| self.q_norm = ( |
| ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None |
| ) |
| self.k_norm = ( |
| ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None |
| ) |
| self.rope = ActionExpertRotaryEmbedding(self.head_dim) if use_rope else None |
| self.qkv = nn.Linear(hidden_size, hidden_size * 3) |
| self.out_proj = nn.Linear(hidden_size, hidden_size) |
| self.out_drop = nn.Dropout(proj_dropout) |
|
|
| def _apply_qk_norm( |
| self, q: torch.Tensor, k: torch.Tensor |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if self.q_norm is None or self.k_norm is None: |
| return q, k |
| return self.q_norm(q), self.k_norm(k) |
|
|
| def _attention( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| v: torch.Tensor, |
| *, |
| attn_mask: Optional[torch.Tensor] = None, |
| is_causal: bool = False, |
| ) -> torch.Tensor: |
| dropout_p = self.attn_dropout if self.training else 0.0 |
| out = F.scaled_dot_product_attention( |
| q.transpose(1, 2), |
| k.transpose(1, 2), |
| v.transpose(1, 2), |
| attn_mask=attn_mask, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| ) |
| return out.transpose(1, 2).contiguous() |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| *, |
| attn_mask: Optional[torch.Tensor] = None, |
| is_causal: bool = False, |
| rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
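| # x: (batch, seq, hidden). A fused QKV projection is split per head; q/k are |
| # optionally RMS-normalized and rotated with RoPE before scaled-dot-product attention. |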
| bsz, seq_len, _ = x.shape |
| qkv = self.qkv(x).view(bsz, seq_len, 3, self.num_heads, self.head_dim) |
| q = qkv[:, :, 0].transpose(1, 2) |
| k = qkv[:, :, 1].transpose(1, 2) |
| v = qkv[:, :, 2].contiguous() |
| q, k = self._apply_qk_norm(q, k) |
| if self.rope is not None: |
| q, k = self.rope(q, k, rope_cache=rope_cache) |
| q = q.transpose(1, 2) |
| k = k.transpose(1, 2) |
| out = self._attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal) |
| out = out.reshape(bsz, seq_len, self.hidden_size) |
| return self.out_drop(self.out_proj(out)) |
|
|
|
|
| class ActionExpertCrossAttention(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| num_heads: int, |
| *, |
| attn_dropout: float = 0.0, |
| proj_dropout: float = 0.0, |
| qk_norm: bool = True, |
| qk_norm_eps: float = 1e-6, |
| ) -> None: |
| super().__init__() |
| if hidden_size % num_heads != 0: |
| raise ValueError("hidden_size must be divisible by num_heads") |
| self.hidden_size = hidden_size |
| self.num_heads = num_heads |
| self.head_dim = hidden_size // num_heads |
| self.attn_dropout = attn_dropout |
| self.q_norm = ( |
| ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None |
| ) |
| self.k_norm = ( |
| ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None |
| ) |
| self.q_proj = nn.Linear(hidden_size, hidden_size) |
| self.kv_proj = nn.Linear(hidden_size, hidden_size * 2) |
| self.out_proj = nn.Linear(hidden_size, hidden_size) |
| self.out_drop = nn.Dropout(proj_dropout) |
|
|
| def _apply_qk_norm( |
| self, q: torch.Tensor, k: torch.Tensor |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if self.q_norm is None or self.k_norm is None: |
| return q, k |
| return self.q_norm(q), self.k_norm(k) |
|
|
| def _as_heads(self, x: torch.Tensor) -> torch.Tensor: |
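| # Accept KV as (batch, seq, hidden), (batch, seq, heads, head_dim), or |
| # (batch, heads, seq, head_dim), returning (batch, seq, heads, head_dim). |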
| if x.dim() == 4: |
| if x.shape[2] == self.num_heads: |
| return x |
| if x.shape[1] == self.num_heads: |
| return x.transpose(1, 2).contiguous() |
| raise ValueError(f"Unexpected cross-attention KV shape {tuple(x.shape)}") |
| if x.dim() != 3: |
| raise ValueError(f"Expected 3D/4D cross-attention KV, got {tuple(x.shape)}") |
| bsz, seq_len, _ = x.shape |
| return x.view(bsz, seq_len, self.num_heads, self.head_dim) |
|
|
| def _attention( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| v: torch.Tensor, |
| *, |
| attn_mask: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| dropout_p = self.attn_dropout if self.training else 0.0 |
| out = F.scaled_dot_product_attention( |
| q.transpose(1, 2), |
| k.transpose(1, 2), |
| v.transpose(1, 2), |
| attn_mask=attn_mask, |
| dropout_p=dropout_p, |
| is_causal=False, |
| ) |
| return out.transpose(1, 2).contiguous() |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| *, |
| kv: Optional[torch.Tensor] = None, |
| kv_k: Optional[torch.Tensor] = None, |
| kv_v: Optional[torch.Tensor] = None, |
| attn_mask: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
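| # Two KV sources: `kv` is projected here via kv_proj (q and k both QK-normed), |
| # while `kv_k`/`kv_v` arrive pre-projected and pre-key-normalized (see |
| # ActionExpert._prepare_kv_context), so only the queries are normalized then. |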
| if (kv_k is None) != (kv_v is None): |
| raise ValueError("kv_k and kv_v must both be provided or both be None.") |
| if kv is not None and kv_k is not None: |
| raise ValueError("Provide either kv or kv_k/kv_v, not both.") |
| bsz, tgt_len, _ = x.shape |
| q = self.q_proj(x).view(bsz, tgt_len, self.num_heads, self.head_dim) |
| if kv_k is not None and kv_v is not None: |
| k = self._as_heads(kv_k) |
| v = self._as_heads(kv_v) |
| k_pre_normed = True |
| else: |
| if kv is None: |
| raise ValueError("cross-attention requires kv or kv_k/kv_v.") |
| src_len = kv.shape[1] |
| kv_proj = self.kv_proj(kv).view( |
| bsz, src_len, 2, self.num_heads, self.head_dim |
| ) |
| k = kv_proj[:, :, 0] |
| v = kv_proj[:, :, 1] |
| k_pre_normed = False |
| q = q.transpose(1, 2) |
| k = k.transpose(1, 2) |
| if k_pre_normed: |
| if self.q_norm is not None: |
| q = self.q_norm(q) |
| else: |
| q, k = self._apply_qk_norm(q, k) |
| q = q.transpose(1, 2) |
| k = k.transpose(1, 2) |
| out = self._attention(q, k, v, attn_mask=attn_mask) |
| out = out.reshape(bsz, tgt_len, self.hidden_size) |
| return self.out_drop(self.out_proj(out)) |
|
|
|
|
| class ActionExpertMLP(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| *, |
| mlp_ratio: float, |
| multiple_of: int, |
| dropout: float = 0.0, |
| ) -> None: |
| super().__init__() |
| inner_dim = _round_up_multiple(int(hidden_size * mlp_ratio), multiple_of) |
| self.up_proj = nn.Linear(hidden_size, inner_dim) |
| self.gate_proj = nn.Linear(hidden_size, inner_dim) |
| self.down_proj = nn.Linear(inner_dim, hidden_size) |
| self.dropout = nn.Dropout(dropout) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
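| # SwiGLU feed-forward: silu(gate) * up, dropout, then the down projection. |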
| x = F.silu(self.gate_proj(x)) * self.up_proj(x) |
| x = self.dropout(x) |
| x = self.down_proj(x) |
| return self.dropout(x) |
|
|
|
|
| class ActionExpertModulation(nn.Module): |
| def __init__(self, hidden_size: int, num_chunks: int) -> None: |
| super().__init__() |
| self.act = nn.SiLU() |
| self.linear = nn.Linear(hidden_size, num_chunks * hidden_size) |
|
|
| def forward(self, conditioning: torch.Tensor) -> torch.Tensor: |
| return self.linear(self.act(conditioning)) |
|
|
|
|
| class ActionExpertBlock(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| num_heads: int, |
| *, |
| mlp_ratio: float, |
| ffn_multiple_of: int, |
| attn_dropout: float = 0.0, |
| dropout: float = 0.0, |
| qk_norm: bool = True, |
| qk_norm_eps: float = 1e-6, |
| rope: bool = True, |
| ) -> None: |
| super().__init__() |
| self.self_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) |
| self.cross_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) |
| self.ff_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) |
| self.self_attn = ActionExpertSelfAttention( |
| hidden_size, |
| num_heads, |
| attn_dropout=attn_dropout, |
| proj_dropout=dropout, |
| qk_norm=qk_norm, |
| qk_norm_eps=qk_norm_eps, |
| use_rope=rope, |
| ) |
| self.cross_attn = ActionExpertCrossAttention( |
| hidden_size, |
| num_heads, |
| attn_dropout=attn_dropout, |
| proj_dropout=dropout, |
| qk_norm=qk_norm, |
| qk_norm_eps=qk_norm_eps, |
| ) |
| self.mlp = ActionExpertMLP( |
| hidden_size, |
| mlp_ratio=mlp_ratio, |
| multiple_of=ffn_multiple_of, |
| dropout=dropout, |
| ) |
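| # 9 modulation chunks: a (shift, scale, gate) triple each for self-attention, cross-attention, and the MLP. |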
| self.modulation = ActionExpertModulation(hidden_size, 9) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| conditioning: torch.Tensor, |
| *, |
| cross_kv: Tuple[torch.Tensor, torch.Tensor], |
| self_attn_mask: Optional[torch.Tensor] = None, |
| attn_mask: Optional[torch.Tensor] = None, |
| is_causal: bool = False, |
| modulation: Optional[Tuple[torch.Tensor, ...]] = None, |
| rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| if modulation is None: |
| modulation = self.modulation(conditioning).chunk(9, dim=1) |
| ( |
| shift_msa, |
| scale_msa, |
| gate_msa, |
| shift_mca, |
| scale_mca, |
| gate_mca, |
| shift_mlp, |
| scale_mlp, |
| gate_mlp, |
| ) = modulation |
| x = x + gate_msa.unsqueeze(1) * self.self_attn( |
| _modulate(self.self_norm(x), shift_msa, scale_msa), |
| attn_mask=self_attn_mask, |
| is_causal=is_causal, |
| rope_cache=rope_cache, |
| ) |
| x = x + gate_mca.unsqueeze(1) * self.cross_attn( |
| _modulate(self.cross_norm(x), shift_mca, scale_mca), |
| kv_k=cross_kv[0], |
| kv_v=cross_kv[1], |
| attn_mask=attn_mask, |
| ) |
| x = x + gate_mlp.unsqueeze(1) * self.mlp( |
| _modulate(self.ff_norm(x), shift_mlp, scale_mlp) |
| ) |
| return x |
|
|
|
|
| class ActionExpertFinalLayer(nn.Module): |
| def __init__(self, hidden_size: int, output_dim: int) -> None: |
| super().__init__() |
| self.norm = ActionExpertRMSNorm(hidden_size, eps=1e-6) |
| self.modulation = ActionExpertModulation(hidden_size, 2) |
| self.linear = nn.Linear(hidden_size, output_dim) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| conditioning: torch.Tensor, |
| *, |
| modulation: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| if modulation is None: |
| modulation = self.modulation(conditioning).chunk(2, dim=1) |
| shift, scale = modulation |
| return self.linear(_modulate(self.norm(x), shift, scale)) |
|
|
|
|
| class SinusoidalTimeEmbedding(nn.Module): |
| def __init__(self, dim: int): |
| super().__init__() |
| self.dim = dim |
|
|
| def forward(self, timesteps: torch.Tensor) -> torch.Tensor: |
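| # Standard sinusoidal embedding of a scalar timestep (half sin, half cos; odd dims are zero-padded). |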
| if timesteps.dim() > 1: |
| timesteps = timesteps.view(timesteps.shape[0], -1)[:, 0] |
| half_dim = self.dim // 2 |
| freq = torch.exp( |
| torch.arange(half_dim, device=timesteps.device, dtype=timesteps.dtype) |
| * (-math.log(10000.0) / max(half_dim - 1, 1)) |
| ) |
| args = timesteps[:, None] * freq[None, :] |
| emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) |
| if self.dim % 2 == 1: |
| emb = F.pad(emb, (0, 1)) |
| return emb |
|
|
|
|
| class ActionExpert(nn.Module): |
| """Modern MolmoAct2 action expert, embedded for HF remote-code inference.""" |
|
|
| def __init__( |
| self, |
| config: MolmoAct2ActionExpertConfig, |
| *, |
| llm_dim: int, |
| llm_kv_dim: int, |
| llm_num_layers: int, |
| device=None, |
| ): |
| super().__init__() |
| if config.implementation != "new": |
| raise ValueError("Only action_expert.implementation='new' is supported.") |
| if config.num_layers != llm_num_layers: |
| raise ValueError( |
| "MolmoAct2 HF action expert supports only per-layer conditioning with one " |
| f"action block per LLM layer (action={config.num_layers}, llm={llm_num_layers})." |
| ) |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.llm_dim = llm_dim |
| self.llm_kv_dim = llm_kv_dim |
| self.action_head_dim = config.hidden_size // config.num_heads |
|
|
| self.time_embed = nn.Sequential( |
| SinusoidalTimeEmbedding(config.timestep_embed_dim), |
| nn.Linear(config.timestep_embed_dim, config.hidden_size, device=device), |
| nn.SiLU(), |
| nn.Linear(config.hidden_size, config.hidden_size, device=device), |
| ) |
| self.action_embed = nn.Linear( |
| config.max_action_dim, config.hidden_size, device=device |
| ) |
| self.state_encoder = nn.Linear( |
| config.hidden_size, config.hidden_size, device=device |
| ) |
| self.state_norm = ActionExpertRMSNorm( |
| config.hidden_size, eps=1e-6, device=device |
| ) |
| self.context_k_proj = nn.Linear( |
| self.llm_kv_dim, config.hidden_size, bias=False, device=device |
| ) |
| self.context_v_proj = nn.Linear( |
| self.llm_kv_dim, config.hidden_size, bias=False, device=device |
| ) |
| self.context_norm = ( |
| ActionExpertRMSNorm(config.hidden_size, eps=1e-6) |
| if config.context_layer_norm |
| else nn.Identity() |
| ) |
| self._modulation_cache_key: Optional[Tuple[Any, ...]] = None |
| self._modulation_cache_value: Optional[Sequence[ActionExpertStepModulation]] = ( |
| None |
| ) |
| self.blocks = nn.ModuleList( |
| [ |
| ActionExpertBlock( |
| config.hidden_size, |
| config.num_heads, |
| mlp_ratio=config.mlp_ratio, |
| ffn_multiple_of=config.ffn_multiple_of, |
| attn_dropout=config.attn_dropout, |
| dropout=config.dropout, |
| qk_norm=config.qk_norm, |
| qk_norm_eps=config.qk_norm_eps, |
| rope=config.rope, |
| ) |
| for _ in range(config.num_layers) |
| ] |
| ) |
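| # kv_proj is unused by the pre-projected K/V context path (context_k_proj / |
| # context_v_proj feed kv_k/kv_v directly), so keep its weights frozen. |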
| for block in self.blocks: |
| block.cross_attn.kv_proj.weight.requires_grad = False |
| if block.cross_attn.kv_proj.bias is not None: |
| block.cross_attn.kv_proj.bias.requires_grad = False |
| self.final_layer = ActionExpertFinalLayer( |
| config.hidden_size, config.max_action_dim |
| ) |
| self.reset_parameters() |
|
|
| def reset_parameters(self) -> None: |
| for module in self.time_embed.modules(): |
| if isinstance(module, nn.Linear): |
| _init_linear(module) |
| _init_linear(self.action_embed) |
| _init_linear(self.state_encoder) |
| self.state_norm.reset_parameters() |
| _init_linear(self.context_k_proj) |
| _init_linear(self.context_v_proj) |
| if isinstance(self.context_norm, ActionExpertRMSNorm): |
| self.context_norm.reset_parameters() |
| residual_scale = (2 * max(self.config.num_layers, 1)) ** -0.5 |
| for block in self.blocks: |
| _init_linear(block.self_attn.qkv) |
| _init_linear(block.self_attn.out_proj, scale=residual_scale) |
| _init_linear(block.cross_attn.q_proj) |
| _init_linear(block.cross_attn.kv_proj) |
| _init_linear(block.cross_attn.out_proj, scale=residual_scale) |
| _init_linear(block.mlp.up_proj) |
| _init_linear(block.mlp.gate_proj) |
| _init_linear(block.mlp.down_proj, scale=residual_scale) |
| _init_linear(block.modulation.linear, zero=True) |
| block.self_norm.reset_parameters() |
| block.cross_norm.reset_parameters() |
| block.ff_norm.reset_parameters() |
| if block.self_attn.q_norm is not None: |
| block.self_attn.q_norm.reset_parameters() |
| if block.self_attn.k_norm is not None: |
| block.self_attn.k_norm.reset_parameters() |
| if block.cross_attn.q_norm is not None: |
| block.cross_attn.q_norm.reset_parameters() |
| if block.cross_attn.k_norm is not None: |
| block.cross_attn.k_norm.reset_parameters() |
| self.final_layer.norm.reset_parameters() |
| _init_linear(self.final_layer.modulation.linear, zero=True) |
| _init_linear(self.final_layer.linear, zero=True) |
|
|
| def _reshape_hidden_to_heads(self, x: torch.Tensor) -> torch.Tensor: |
| return x.view( |
| x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim |
| ) |
|
|
| def _encode_states(self, states: Optional[torch.Tensor]) -> Optional[torch.Tensor]: |
| if states is None: |
| return None |
| if states.dim() == 2: |
| states = states.unsqueeze(1) |
| if states.shape[-1] != self.hidden_size: |
| feat_dim = states.shape[-1] |
| if feat_dim < self.hidden_size: |
| states = F.pad(states, (0, self.hidden_size - feat_dim)) |
| else: |
| states = states[..., : self.hidden_size] |
| return self.state_norm(self.state_encoder(states)) |
|
|
| def _project_kv_tensor(self, x: torch.Tensor, proj: nn.Linear) -> torch.Tensor: |
| flat = self.context_norm(proj(x)) |
| return self._reshape_hidden_to_heads(flat) |
|
|
| def _prepare_kv_context( |
| self, |
| encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], |
| encoded_states: Optional[torch.Tensor], |
| ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]: |
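| # Project each LLM layer's cached K/V into the expert's head space, append the |
| # encoded state tokens, and pre-apply that block's cross-attention key norm. |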
| if len(encoder_kv_states) != len(self.blocks): |
| raise ValueError( |
| f"Expected {len(self.blocks)} KV layers for per-layer conditioning, " |
| f"got {len(encoder_kv_states)}." |
| ) |
| kv_contexts = [] |
| state_heads = ( |
| self._reshape_hidden_to_heads(encoded_states) |
| if encoded_states is not None |
| else None |
| ) |
| for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states): |
| k_ctx = self._project_kv_tensor(k_in, self.context_k_proj) |
| v_ctx = self._project_kv_tensor(v_in, self.context_v_proj) |
| if state_heads is not None: |
| k_ctx = torch.cat([k_ctx, state_heads], dim=1) |
| v_ctx = torch.cat([v_ctx, state_heads], dim=1) |
| k_norm = block.cross_attn.k_norm |
| if k_norm is not None: |
| k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2) |
| kv_contexts.append((k_ctx, v_ctx)) |
| return kv_contexts |
|
|
| @staticmethod |
| def _build_cross_attention_mask( |
| encoder_attention_mask: Optional[torch.Tensor], |
| encoded_states: Optional[torch.Tensor], |
| batch_size: int, |
| dtype: torch.dtype, |
| ) -> Optional[torch.Tensor]: |
| state_seq_len = 0 if encoded_states is None else encoded_states.shape[1] |
| if encoder_attention_mask is None: |
| return None |
| mask = encoder_attention_mask[:, None, None, :].to(dtype=dtype) |
| if state_seq_len > 0: |
| ones = torch.ones( |
| batch_size, |
| 1, |
| 1, |
| state_seq_len, |
| device=mask.device, |
| dtype=mask.dtype, |
| ) |
| mask = torch.cat([mask, ones], dim=-1) |
| return (1.0 - mask) * torch.finfo(dtype).min |
|
|
| def _build_self_attention_mask( |
| self, |
| action_attention_mask: Optional[torch.Tensor], |
| seq_len: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| ) -> Optional[torch.Tensor]: |
| mask = None |
| if action_attention_mask is not None: |
| valid = action_attention_mask.to(device=device, dtype=torch.bool) |
| key_mask = (~valid)[:, None, None, :].to(dtype=dtype) |
| mask = key_mask * torch.finfo(dtype).min |
| if self.config.causal_attn: |
| causal = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool).triu( |
| diagonal=1 |
| ) |
| causal = ( |
| causal.unsqueeze(0).unsqueeze(0).to(dtype=dtype) |
| * torch.finfo(dtype).min |
| ) |
| mask = causal if mask is None else mask + causal |
| return mask |
|
|
| def prepare_context( |
| self, |
| *, |
| encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], |
| encoder_attention_mask: Optional[torch.Tensor] = None, |
| action_attention_mask: Optional[torch.Tensor] = None, |
| state_embeddings: Optional[torch.Tensor] = None, |
| batch_size: int, |
| seq_len: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| ) -> ActionExpertContext: |
| encoded_states = self._encode_states(state_embeddings) |
| valid_action = None |
| if action_attention_mask is not None: |
| valid_action = action_attention_mask.to( |
| device=device, dtype=dtype |
| ).unsqueeze(-1) |
| rope_cache = None |
| if len(self.blocks) > 0 and self.blocks[0].self_attn.rope is not None: |
| rope_cache = self.blocks[0].self_attn.rope.build_cache( |
| seq_len=seq_len, |
| device=device, |
| dtype=dtype, |
| ) |
| kv_contexts = self._prepare_kv_context(encoder_kv_states, encoded_states) |
| cross_mask = self._build_cross_attention_mask( |
| encoder_attention_mask, |
| encoded_states, |
| batch_size, |
| dtype, |
| ) |
| self_mask = self._build_self_attention_mask( |
| action_attention_mask, seq_len, device, dtype |
| ) |
| return ActionExpertContext( |
| kv_contexts=kv_contexts, |
| cross_mask=cross_mask, |
| self_mask=self_mask, |
| valid_action=valid_action, |
| rope_cache=rope_cache, |
| ) |
|
|
| def prepare_modulation_cache( |
| self, |
| timesteps: Sequence[torch.Tensor], |
| ) -> Sequence[ActionExpertStepModulation]: |
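| # The modulations depend only on the timesteps, so they can be precomputed once per denoising schedule. |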
| cache = [] |
| for idx, step_t in enumerate(timesteps): |
| conditioning = self.time_embed(step_t) |
| block_modulations = [] |
| for block in self.blocks: |
| block_modulations.append( |
| tuple(block.modulation(conditioning).chunk(9, dim=1)) |
| ) |
| final_modulation = tuple( |
| self.final_layer.modulation(conditioning).chunk(2, dim=1) |
| ) |
| cache.append( |
| ActionExpertStepModulation( |
| conditioning=conditioning, |
| block_modulations=block_modulations, |
| final_modulation=final_modulation, |
| ) |
| ) |
| return cache |
|
|
| def get_or_prepare_modulation_cache( |
| self, |
| timesteps: Sequence[torch.Tensor], |
| *, |
| cache_key: Optional[Tuple[Any, ...]] = None, |
| ) -> Sequence[ActionExpertStepModulation]: |
| if self.training or cache_key is None: |
| return self.prepare_modulation_cache(timesteps) |
| if ( |
| self._modulation_cache_key == cache_key |
| and self._modulation_cache_value is not None |
| ): |
| return self._modulation_cache_value |
| cached = self.prepare_modulation_cache(timesteps) |
| self._modulation_cache_key = cache_key |
| self._modulation_cache_value = cached |
| return cached |
|
|
| def forward_with_context( |
| self, |
| actions: torch.Tensor, |
| timesteps: torch.Tensor, |
| *, |
| context: ActionExpertContext, |
| modulation: Optional[ActionExpertStepModulation] = None, |
| ) -> torch.Tensor: |
| bsz, seq_len, _ = actions.shape |
| if seq_len > self.config.max_horizon: |
| raise ValueError( |
| f"Action sequence length {seq_len} exceeds configured max_horizon={self.config.max_horizon}" |
| ) |
| if modulation is None: |
| conditioning = self.time_embed(timesteps) |
| block_modulations: Sequence[Optional[Tuple[torch.Tensor, ...]]] = [ |
| None |
| ] * len(self.blocks) |
| final_modulation = None |
| else: |
| conditioning = modulation.conditioning |
| block_modulations = modulation.block_modulations |
| final_modulation = modulation.final_modulation |
| x = self.action_embed(actions) |
| if context.valid_action is not None: |
| x = x * context.valid_action |
| for idx, (block, kv_context, block_modulation) in enumerate( |
| zip(self.blocks, context.kv_contexts, block_modulations) |
| ): |
| x = block( |
| x, |
| conditioning, |
| cross_kv=kv_context, |
| self_attn_mask=context.self_mask, |
| attn_mask=context.cross_mask, |
| is_causal=self.config.causal_attn, |
| modulation=block_modulation, |
| rope_cache=context.rope_cache, |
| ) |
| if context.valid_action is not None: |
| x = x * context.valid_action |
| out = self.final_layer(x, conditioning, modulation=final_modulation) |
| if context.valid_action is not None: |
| out = out * context.valid_action |
| return out |
|
|
| def forward( |
| self, |
| actions: torch.Tensor, |
| timesteps: torch.Tensor, |
| *, |
| encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], |
| encoder_attention_mask: Optional[torch.Tensor] = None, |
| action_attention_mask: Optional[torch.Tensor] = None, |
| state_embeddings: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| bsz, seq_len, _ = actions.shape |
| context = self.prepare_context( |
| encoder_kv_states=encoder_kv_states, |
| encoder_attention_mask=encoder_attention_mask, |
| action_attention_mask=action_attention_mask, |
| state_embeddings=state_embeddings, |
| batch_size=bsz, |
| seq_len=seq_len, |
| device=actions.device, |
| dtype=actions.dtype, |
| ) |
| return self.forward_with_context(actions, timesteps, context=context) |
|
|
|
|
| def _to_numpy(value: Any) -> np.ndarray: |
| if isinstance(value, np.ndarray): |
| return value |
| if torch.is_tensor(value): |
| return value.detach().cpu().numpy() |
| return np.asarray(value) |
|
|
|
|
| def _to_array(value: Any) -> Optional[np.ndarray]: |
| if value is None: |
| return None |
| if torch.is_tensor(value): |
| return value.detach().cpu().numpy().astype(np.float32, copy=False) |
| return np.asarray(value, dtype=np.float32) |
|
|
|
|
| def _to_mask(value: Any, fallback_like: Optional[np.ndarray]) -> Optional[np.ndarray]: |
| if value is None: |
| return None |
| mask = np.asarray(value, dtype=np.bool_) |
| if fallback_like is not None and mask.shape != fallback_like.shape: |
| mask = np.broadcast_to(mask, fallback_like.shape) |
| return mask |
|
|
|
|
| def _feature_dim_from_stats(stats: Optional[Mapping[str, Any]]) -> Optional[int]: |
| if not isinstance(stats, Mapping): |
| return None |
| for key in ( |
| "mean", |
| "std", |
| "min", |
| "max", |
| "q01", |
| "q99", |
| "q10", |
| "q90", |
| "mask", |
| "names", |
| ): |
| value = stats.get(key) |
| if value is None: |
| continue |
| arr = np.asarray(value) |
| if arr.shape: |
| return int(arr.shape[-1]) |
| if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): |
| return int(len(value)) |
| return None |
|
|
|
|
| class _FeatureNormalizer: |
| def __init__( |
| self, |
| *, |
| mode: str, |
| mean: Optional[np.ndarray] = None, |
| std: Optional[np.ndarray] = None, |
| min_val: Optional[np.ndarray] = None, |
| max_val: Optional[np.ndarray] = None, |
| q_low: Optional[np.ndarray] = None, |
| q_high: Optional[np.ndarray] = None, |
| mask: Optional[np.ndarray] = None, |
| zero_mask: Optional[np.ndarray] = None, |
| ): |
| self.mode = mode |
| self.mean = mean |
| self.std = std |
| self.min_val = min_val |
| self.max_val = max_val |
| self.q_low = q_low |
| self.q_high = q_high |
| self.mask = mask |
| self.zero_mask = zero_mask |
|
|
| @classmethod |
| def from_stats( |
| cls, stats: Optional[Mapping[str, Any]], mode: str |
| ) -> Optional["_FeatureNormalizer"]: |
| if stats is None: |
| return None |
| raw_mask = stats.get("mask") if isinstance(stats, Mapping) else None |
| if mode == "none": |
| fallback = None |
| for key in ( |
| "mean", |
| "std", |
| "min", |
| "max", |
| "q01", |
| "q99", |
| "q10", |
| "q90", |
| "mask", |
| ): |
| fallback = _to_array(stats.get(key)) |
| if fallback is not None: |
| break |
| return cls(mode=mode, mask=_to_mask(raw_mask, fallback)) |
| if mode == "mean_std": |
| mean = _to_array(stats.get("mean")) |
| std = _to_array(stats.get("std")) |
| if mean is None or std is None: |
| raise ValueError("norm_mode='mean_std' requires mean and std stats.") |
| return cls(mode=mode, mean=mean, std=std, mask=_to_mask(raw_mask, mean)) |
| if mode == "min_max": |
| min_val = _to_array(stats.get("min")) |
| max_val = _to_array(stats.get("max")) |
| if min_val is None or max_val is None: |
| raise ValueError("norm_mode='min_max' requires min and max stats.") |
| return cls( |
| mode=mode, |
| min_val=min_val, |
| max_val=max_val, |
| mask=_to_mask(raw_mask, min_val), |
| zero_mask=(min_val == max_val), |
| ) |
| if mode in {"q01_q99", "q10_q90"}: |
| low_key, high_key = ("q01", "q99") if mode == "q01_q99" else ("q10", "q90") |
| q_low = _to_array(stats.get(low_key)) |
| q_high = _to_array(stats.get(high_key)) |
| if q_low is None or q_high is None: |
| raise ValueError( |
| f"norm_mode={mode!r} requires {low_key} and {high_key} stats." |
| ) |
| min_val = _to_array(stats.get("min")) |
| max_val = _to_array(stats.get("max")) |
| fallback = min_val if min_val is not None else q_low |
| zero_mask = ( |
| None if min_val is None or max_val is None else (min_val == max_val) |
| ) |
| return cls( |
| mode=mode, |
| min_val=min_val, |
| max_val=max_val, |
| q_low=q_low, |
| q_high=q_high, |
| mask=_to_mask(raw_mask, fallback), |
| zero_mask=zero_mask, |
| ) |
| raise ValueError(f"Unsupported robot normalization mode {mode!r}.") |
|
|
| def normalize(self, x: Any) -> Any: |
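| # mean_std: (x - mean) / std; min_max and quantile modes rescale to [-1, 1] and clip. |
| # Masked-out dimensions pass through unchanged; dimensions with min == max map to 0. |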
| arr = _to_array(x) |
| if arr is None: |
| return None |
| eps = 1e-6 |
| if self.mode == "none": |
| normed = arr |
| elif self.mode == "mean_std": |
| normed = (arr - self.mean) / np.maximum(self.std, eps) |
| elif self.mode == "min_max": |
| normed = ( |
| 2.0 |
| * (arr - self.min_val) |
| / np.maximum(self.max_val - self.min_val, eps) |
| - 1.0 |
| ) |
| elif self.mode in {"q01_q99", "q10_q90"}: |
| normed = ( |
| 2.0 * (arr - self.q_low) / np.maximum(self.q_high - self.q_low, eps) |
| - 1.0 |
| ) |
| else: |
| normed = arr |
| if self.mode in {"min_max", "q01_q99", "q10_q90"}: |
| normed = np.clip(normed, -1.0, 1.0) |
| if self.mask is not None: |
| normed = np.where(self.mask, normed, arr) |
| if self.zero_mask is not None: |
| normed = np.where(self.zero_mask, 0.0, normed) |
| if torch.is_tensor(x): |
| return torch.as_tensor(normed, device=x.device, dtype=x.dtype) |
| return normed |
|
|
| def unnormalize(self, x: Any) -> Any: |
| arr = _to_array(x) |
| if arr is None: |
| return None |
| if self.mode in {"min_max", "q01_q99", "q10_q90"}: |
| arr = np.clip(arr, -1.0, 1.0) |
| if self.mode == "none": |
| out = arr |
| elif self.mode == "mean_std": |
| out = arr * self.std + self.mean |
| elif self.mode == "min_max": |
| out = (arr + 1.0) * (self.max_val - self.min_val) / 2.0 + self.min_val |
| elif self.mode in {"q01_q99", "q10_q90"}: |
| out = (arr + 1.0) * (self.q_high - self.q_low) / 2.0 + self.q_low |
| else: |
| out = arr |
| if self.mask is not None: |
| out = np.where(self.mask, out, arr) |
| if torch.is_tensor(x): |
| return torch.as_tensor(out, device=x.device, dtype=x.dtype) |
| return out |
|
|
|
|
| class _RobotStats: |
| def __init__(self, payload: Mapping[str, Any]): |
| self.norm_mode = str(payload.get("norm_mode", "min_max")) |
| self.metadata_by_tag: Dict[str, Dict[str, Any]] = { |
| str(tag): dict(metadata or {}) |
| for tag, metadata in dict(payload.get("metadata_by_tag") or {}).items() |
| } |
| self.action_normalizers = {} |
| self.state_normalizers = {} |
| for tag, metadata in self.metadata_by_tag.items(): |
| if metadata.get("action_stats") is not None: |
| self.action_normalizers[tag] = _FeatureNormalizer.from_stats( |
| metadata.get("action_stats"), |
| self.norm_mode, |
| ) |
| if metadata.get("state_stats") is not None: |
| self.state_normalizers[tag] = _FeatureNormalizer.from_stats( |
| metadata.get("state_stats"), |
| self.norm_mode, |
| ) |
|
|
| def validate_tag(self, norm_tag: Optional[str]) -> str: |
| tag = str(norm_tag or "").strip() |
| if not tag: |
| raise ValueError("MolmoAct2 `predict_action` requires `norm_tag`.") |
| if tag not in self.metadata_by_tag: |
| allowed = ", ".join(sorted(self.metadata_by_tag)) |
| raise ValueError( |
| f"Unknown MolmoAct2 normalization tag {tag!r}. Allowed tags: {allowed}." |
| ) |
| return tag |
|
|
| def get_metadata(self, norm_tag: Optional[str]) -> Dict[str, Any]: |
| if norm_tag is None: |
| return {} |
| return dict(self.metadata_by_tag.get(str(norm_tag), {}) or {}) |
|
|
| def normalize_state(self, state: Any, norm_tag: str) -> Any: |
| normalizer = self.state_normalizers.get(str(norm_tag)) |
| return state if normalizer is None else normalizer.normalize(state) |
|
|
| def unnormalize_action(self, action: Any, norm_tag: str) -> Any: |
| normalizer = self.action_normalizers.get(str(norm_tag)) |
| return action if normalizer is None else normalizer.unnormalize(action) |
|
|
| def get_action_dim(self, norm_tag: str) -> Optional[int]: |
| metadata = self.get_metadata(norm_tag) |
| stats = metadata.get("action_stats") |
| dim = _feature_dim_from_stats(stats) |
| return dim |
|
|
| def get_state_dim(self, norm_tag: str) -> Optional[int]: |
| metadata = self.get_metadata(norm_tag) |
| return _feature_dim_from_stats(metadata.get("state_stats")) |
|
|
| def get_action_horizon(self, norm_tag: str) -> Optional[int]: |
| return self._get_positive_int(norm_tag, "action_horizon") |
|
|
| def get_n_action_steps(self, norm_tag: str) -> Optional[int]: |
| return self._get_positive_int(norm_tag, "n_action_steps") |
|
|
| def _get_positive_int(self, norm_tag: str, key: str) -> Optional[int]: |
| value = self.get_metadata(norm_tag).get(key) |
| if value is None: |
| return None |
| value = int(value) |
| if value < 1: |
| raise ValueError( |
| f"Robot metadata for norm_tag={norm_tag!r} must define {key} >= 1." |
| ) |
| return value |
|
|
|
|
| def _normalize_image_for_cache(image: Any) -> np.ndarray: |
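| # Convert an arbitrary image array to HWC uint8 RGB (grayscale and channel-first inputs are handled). |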
| arr = np.asarray(image) |
| if arr.ndim == 2: |
| arr = np.stack([arr] * 3, axis=-1) |
| if arr.ndim == 3 and arr.shape[0] in {1, 3, 4} and arr.shape[-1] not in {1, 3, 4}: |
| arr = np.moveaxis(arr, 0, -1) |
| if arr.ndim == 3 and arr.shape[-1] == 1: |
| arr = np.repeat(arr, 3, axis=-1) |
| if arr.dtype in (np.float32, np.float64): |
| if arr.size > 0 and float(arr.max()) <= 1.0: |
| arr = arr * 255.0 |
| arr = np.clip(arr, 0, 255).astype(np.uint8) |
| elif arr.dtype != np.uint8: |
| arr = np.clip(arr, 0, 255).astype(np.uint8) |
| return arr |
|
|
|
|
| def _extract_first_image(images: Any) -> Optional[np.ndarray]: |
| if images is None: |
| return None |
| if isinstance(images, (list, tuple)): |
| if not images: |
| return None |
| return _normalize_image_for_cache(images[0]) |
| arr = _to_numpy(images) |
| if arr.ndim == 4: |
| return _normalize_image_for_cache(arr[0]) |
| return _normalize_image_for_cache(arr) |
|
|
|
|
| def _resize_depth_reasoning_image(image: np.ndarray, target_size: int) -> np.ndarray: |
| from PIL import Image |
|
|
| if image.shape[0] == target_size and image.shape[1] == target_size: |
| return image |
| pil_image = Image.fromarray(np.asarray(image, dtype=np.uint8)) |
| return np.asarray(pil_image.resize((target_size, target_size), Image.BILINEAR)) |
|
|
|
|
| def _compute_depth_update_mask( |
| current_image: np.ndarray, |
| previous_image: np.ndarray, |
| *, |
| num_depth_codes: int, |
| ) -> np.ndarray: |
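| # Cosine similarity between corresponding 32x32 pixel patches of the current and |
| # previous frames; patches below the similarity threshold are marked True (changed). |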
| grid_side = int(math.isqrt(int(num_depth_codes))) |
| if grid_side * grid_side != int(num_depth_codes): |
| raise ValueError( |
| f"enable_adaptive_depth=True requires a square depth grid, got num_depth_codes={int(num_depth_codes)}." |
| ) |
| target_size = grid_side * _DEPTH_REASONING_PATCH_SIZE |
| current_resized = _resize_depth_reasoning_image(current_image, target_size).astype( |
| np.float32 |
| ) |
| previous_resized = _resize_depth_reasoning_image( |
| previous_image, target_size |
| ).astype(np.float32) |
| current_patches = ( |
| current_resized.reshape( |
| grid_side, |
| _DEPTH_REASONING_PATCH_SIZE, |
| grid_side, |
| _DEPTH_REASONING_PATCH_SIZE, |
| 3, |
| ) |
| .transpose(0, 2, 1, 3, 4) |
| .reshape(grid_side, grid_side, -1) |
| ) |
| previous_patches = ( |
| previous_resized.reshape( |
| grid_side, |
| _DEPTH_REASONING_PATCH_SIZE, |
| grid_side, |
| _DEPTH_REASONING_PATCH_SIZE, |
| 3, |
| ) |
| .transpose(0, 2, 1, 3, 4) |
| .reshape(grid_side, grid_side, -1) |
| ) |
| dot = np.sum(current_patches * previous_patches, axis=-1) |
| norm_current = np.linalg.norm(current_patches, axis=-1) |
| norm_previous = np.linalg.norm(previous_patches, axis=-1) |
| denom = norm_current * norm_previous |
| similarity = np.where(denom < 1e-8, 1.0, dot / (denom + 1e-12)) |
| return np.asarray(similarity < _DEPTH_REASONING_THRESHOLD, dtype=np.bool_).reshape( |
| -1 |
| ) |
|
|
|
|
| def _build_depth_update_spans( |
| update_mask: Sequence[bool], |
| ) -> List[Tuple[int, int, bool]]: |
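| # Run-length encode the boolean mask into contiguous (start, end, value) spans. |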
| flat_mask = np.asarray(update_mask, dtype=np.bool_).reshape(-1) |
| if flat_mask.size == 0: |
| return [] |
| spans: List[Tuple[int, int, bool]] = [] |
| start = 0 |
| current_value = bool(flat_mask[0]) |
| for idx in range(1, int(flat_mask.shape[0])): |
| next_value = bool(flat_mask[idx]) |
| if next_value == current_value: |
| continue |
| spans.append((start, idx, current_value)) |
| start = idx |
| current_value = next_value |
| spans.append((start, int(flat_mask.shape[0]), current_value)) |
| return spans |
|
|
|
|
| def _wrap_setup_text(setup_type: str, add_setup_tokens: bool = False) -> str: |
| setup_type = str(setup_type or "") |
| if setup_type.startswith(SETUP_START_TOKEN) and setup_type.endswith( |
| SETUP_END_TOKEN |
| ): |
| return setup_type |
| if not setup_type or not add_setup_tokens: |
| return setup_type |
| return f"{SETUP_START_TOKEN}{setup_type}{SETUP_END_TOKEN}" |
|
|
|
|
| def _wrap_control_text(control_mode: str, add_control_tokens: bool = False) -> str: |
| control_mode = str(control_mode or "") |
| if control_mode.startswith(CONTROL_START_TOKEN) and control_mode.endswith( |
| CONTROL_END_TOKEN |
| ): |
| return control_mode |
| if not control_mode or not add_control_tokens: |
| return control_mode |
| return f"{CONTROL_START_TOKEN}{control_mode}{CONTROL_END_TOKEN}" |
|
|
|
|
| def _discretize_normalized_state( |
| state: np.ndarray, num_state_tokens: int |
| ) -> np.ndarray: |
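| # Map values in [-1, 1] to integer bins in [0, num_state_tokens - 1]. |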
| arr = np.asarray(state, dtype=np.float32) |
| arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0) |
| arr = np.clip(arr, -1.0, 1.0) |
| scaled = (arr + 1.0) / 2.0 * float(num_state_tokens - 1) |
| return np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1) |
|
|
|
|
| def _build_discrete_state_string( |
| state: Optional[np.ndarray], num_state_tokens: int |
| ) -> str: |
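| # Produces "<state_start><state_K>...<state_end>" with one token per state dimension. |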
| if state is None: |
| return "" |
| token_ids = _discretize_normalized_state(state, num_state_tokens).reshape(-1) |
| return f"{STATE_START_TOKEN}{''.join(f'{STATE_TOKEN_PREFIX}{int(token_id)}>' for token_id in token_ids)}{STATE_END_TOKEN}" |
|
|
|
|
| def _normalize_question_text(text: str) -> str: |
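| # Canonicalize a task/instruction string: collapse whitespace, strip wrapping quotes |
| # and "task:"-style prefixes, drop trailing punctuation, join sentences with "; ", and lowercase. |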
| normalized = re.sub(r"\s+", " ", text).strip() |
| if not normalized: |
| return "" |
| previous = None |
| while normalized and normalized != previous: |
| previous = normalized |
| normalized = normalized.strip().strip(_QUESTION_SURROUNDING_DELIMITERS).strip() |
| for pattern in _QUESTION_PREFIX_PATTERNS: |
| normalized = pattern.sub("", normalized, count=1).strip() |
| normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() |
| normalized = normalized.rstrip(_QUESTION_TRAILING_CLOSERS).rstrip() |
| normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip() |
| sentence_chunks = [ |
| chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip() |
| ] |
| if len(sentence_chunks) > 1: |
| normalized = "; ".join(sentence_chunks) |
| normalized = normalized.lower() |
| return normalized |
|
|
|
|
| def _build_robot_text( |
| *, |
| task: str, |
| style: str, |
| discrete_state_string: str, |
| setup_type: str, |
| control_mode: str, |
| add_setup_tokens: bool, |
| add_control_tokens: bool, |
| num_images: int, |
| ) -> str: |
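| # Build the chat-formatted robot prompt: optional image prefix, task/setup/state/control |
| # clauses, and an assistant turn primed with the depth and/or action output trigger tokens. |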
| setup_text = _wrap_setup_text(setup_type, add_setup_tokens=add_setup_tokens) |
| control_text = _wrap_control_text( |
| control_mode, add_control_tokens=add_control_tokens |
| ) |
| state_clause = ( |
| f" The current state of the robot is {discrete_state_string}." |
| if discrete_state_string |
| else "" |
| ) |
| if style == "robot_depth_action": |
| prompt = ( |
| f"The task is to {task}. The setup is {setup_text}.{state_clause} " |
| f"The expected control mode is {control_text}. Given these, first predict the depth map of the main image " |
| "and then predict the action the robot should take to complete the task?" |
| ) |
| trigger = f"{DEPTH_OUTPUT_TOKEN}{ACTION_OUTPUT_TOKEN}" |
| else: |
| prompt = ( |
| f"The task is to {task}. The setup is {setup_text}.{state_clause} " |
| f"The expected control mode is {control_text}. Given these, what action should the robot take to complete the task?" |
| ) |
| trigger = ACTION_OUTPUT_TOKEN |
| if num_images <= 0: |
| image_prefix = "" |
| elif num_images == 1: |
| image_prefix = "<|image|>" |
| else: |
| image_prefix = "".join(f"Image {idx + 1}<|image|>" for idx in range(num_images)) |
| return f"{image_prefix}<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{trigger}" |
|
|
|
|
| def _flatten_generated_token_ids(token_ids: torch.Tensor) -> List[int]: |
| if token_ids.ndim == 3: |
| return [int(x) for x in token_ids[0, 0].detach().cpu().tolist()] |
| if token_ids.ndim == 2: |
| return [int(x) for x in token_ids[0].detach().cpu().tolist()] |
| if token_ids.ndim == 1: |
| return [int(x) for x in token_ids.detach().cpu().tolist()] |
| raise ValueError( |
| f"Unexpected generated token tensor shape {tuple(token_ids.shape)}" |
| ) |
|
|
|
|
| def _extract_discrete_token_bins( |
| generated_ids: List[int], |
| start_token_id: int, |
| end_token_id: int, |
| token_id_to_bin: Dict[int, int], |
| ) -> List[int]: |
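| # Collect bin indices between the start/end marker tokens, tolerating missing markers. |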
| start_idx = None |
| end_idx = None |
| for idx, token_id in enumerate(generated_ids): |
| if token_id == start_token_id: |
| start_idx = idx |
| break |
| if start_idx is not None: |
| for idx in range(start_idx + 1, len(generated_ids)): |
| if generated_ids[idx] == end_token_id: |
| end_idx = idx |
| break |
| span_start = 0 if start_idx is None else start_idx + 1 |
| span_end = len(generated_ids) if end_idx is None else end_idx |
| return [ |
| int(token_id_to_bin[token_id]) |
| for token_id in generated_ids[span_start:span_end] |
| if token_id in token_id_to_bin |
| ] |
|
|
|
|
| @dataclass |
| class MolmoAct2ActionOutput(ModelOutput): |
| actions: Optional[torch.FloatTensor] = None |
| generated_token_ids: Optional[torch.LongTensor] = None |
| depth_bins: Optional[torch.LongTensor] = None |
| depth_cache: Optional[Dict[str, Any]] = None |
|
|
|
|
| @dataclass |
| class _DepthPrefix: |
| token_ids: torch.Tensor |
| depth_bins: torch.Tensor |
| full_input_ids: torch.Tensor |
| attention_mask: Optional[torch.Tensor] |
| encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]] |
| next_output: Any |
| past_key_values: Optional[Cache] |
|
|
|
|
| @dataclass |
| class MolmoAct2CausalLMOutputWithPast(ModelOutput): |
| """ |
| Base class for MolmoAct2 causal language model (or autoregressive) outputs. |
| |
| Args: |
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
| Language modeling loss (for next-token prediction). |
| logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
| Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
| past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
| It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). |
| |
| Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see |
| `past_key_values` input) to speed up sequential decoding. |
| image_hidden_states (`torch.FloatTensor`, *optional*): |
| A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. |
| Image hidden states produced by the vision encoder, after projecting the last hidden state. |
| """ |
|
|
| loss: Optional[torch.FloatTensor] = None |
| logits: Optional[torch.FloatTensor] = None |
| past_key_values: Optional[Cache] = None |
| hidden_states: Optional[tuple[torch.FloatTensor]] = None |
| attentions: Optional[tuple[torch.FloatTensor]] = None |
| image_hidden_states: Optional[torch.FloatTensor] = None |
|
|
|
|
| @dataclass |
| class MolmoAct2ModelOutputWithPast(BaseModelOutputWithPast): |
| """ |
| Base class for MolmoAct2 outputs, with hidden states and attentions. |
| |
| Args: |
| image_hidden_states (`torch.FloatTensor`, *optional*): |
| A `torch.FloatTensor` of size `(batch_num_patches, hidden_size)`. |
| Image hidden states produced by the vision backbone. |
| """ |
|
|
| last_hidden_state: Optional[torch.FloatTensor] = None |
| past_key_values: Optional[Cache] = None |
| hidden_states: Optional[tuple[torch.FloatTensor]] = None |
| attentions: Optional[tuple[torch.FloatTensor]] = None |
| image_hidden_states: Optional[torch.FloatTensor] = None |
|
|
|
|
| class ViTMLP(nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| hidden_dim: int, |
| hidden_act: str, |
| device: Union[str, torch.device] = None, |
| ): |
| super().__init__() |
| self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device) |
| self.act = ACT2FN[hidden_act] |
| self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.w2(self.act(self.w1(x))) |
|
|
|
|
| class ViTMultiHeadDotProductAttention(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| num_heads: int, |
| num_key_value_heads: int, |
| head_dim: int, |
| use_bias: bool = True, |
| input_dim: Optional[int] = None, |
| float32_attention: bool = True, |
| attention_dropout: float = 0.0, |
| residual_dropout: float = 0.0, |
| device: Union[str, torch.device] = None, |
| attn_implementation: str = "eager", |
| ): |
| super().__init__() |
|
|
| self.hidden_size = hidden_size |
| self.num_heads = num_heads |
| self.head_dim = head_dim |
| self.num_key_value_heads = num_key_value_heads |
| self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
| self.attn_implementation = attn_implementation |
| self.is_causal = False |
|
|
| input_dim = input_dim or hidden_size |
|
|
| self.wq = nn.Linear( |
| input_dim, |
| self.num_heads * self.head_dim, |
| bias=use_bias, |
| device=device, |
| ) |
| self.wk = nn.Linear( |
| input_dim, |
| self.num_key_value_heads * self.head_dim, |
| bias=use_bias, |
| device=device, |
| ) |
| self.wv = nn.Linear( |
| input_dim, |
| self.num_key_value_heads * self.head_dim, |
| bias=use_bias, |
| device=device, |
| ) |
| self.wo = nn.Linear( |
| self.num_heads * self.head_dim, |
| self.hidden_size, |
| ) |
| self.float32_attention = float32_attention |
| self.attention_dropout = attention_dropout |
| self.residual_dropout = nn.Dropout(residual_dropout) |
| self.sdpa_backend_list = [ |
| SDPBackend.FLASH_ATTENTION, |
| SDPBackend.CUDNN_ATTENTION, |
| SDPBackend.EFFICIENT_ATTENTION, |
| SDPBackend.MATH, |
| ] |
|
|
| def _split_heads(self, hidden_states, num_heads) -> torch.Tensor: |
| return hidden_states.reshape( |
| hidden_states.shape[:2] + (num_heads, self.head_dim) |
| ) |
|
|
| def _merge_heads(self, hidden_states) -> torch.Tensor: |
| return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) |
|
|
| def forward( |
| self, |
| inputs_q: torch.Tensor, |
| inputs_kv: Optional[torch.Tensor] = None, |
| attn_mask: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
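| # Self-attention when inputs_kv is None, cross-attention otherwise; K/V heads are |
| # repeated to match the query heads (GQA) and attention runs in eager, SDPA, or |
| # flash_attention_2 mode depending on attn_implementation. |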
|
|
| if inputs_kv is not None: |
| inputs_k = inputs_kv |
| inputs_v = inputs_kv |
| else: |
| inputs_k = inputs_q |
| inputs_v = inputs_q |
|
|
| xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v) |
|
|
| xq = self._split_heads(xq, self.num_heads) |
| xk = self._split_heads(xk, self.num_key_value_heads) |
| xv = self._split_heads(xv, self.num_key_value_heads) |
|
|
| if self.num_heads != self.num_key_value_heads: |
| xk = xk.repeat_interleave( |
| self.num_key_value_groups, dim=2, output_size=self.num_heads |
| ) |
| xv = xv.repeat_interleave( |
| self.num_key_value_groups, dim=2, output_size=self.num_heads |
| ) |
|
|
| og_dtype = xq.dtype |
|
|
| if self.float32_attention: |
| xq = xq.to(torch.float) |
| xk = xk.to(torch.float) |
|
|
| dropout_p = 0.0 if not self.training else self.attention_dropout |
|
|
| if self.attn_implementation == "eager": |
| attn_weights = torch.einsum( |
| "...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk |
| ) |
| attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( |
| xq.dtype |
| ) |
| attn_weights = F.dropout(attn_weights, p=dropout_p, training=self.training) |
| attn_output = torch.einsum( |
| "...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv |
| ) |
|
|
| elif self.attn_implementation == "sdpa": |
| if self.float32_attention: |
| xv = xv.to(torch.float32) |
|
|
| query = xq.transpose(1, 2).contiguous() |
| key = xk.transpose(1, 2).contiguous() |
| value = xv.transpose(1, 2).contiguous() |
| if inputs_kv is not None: |
| with sdpa_kernel(self.sdpa_backend_list): |
| attn_output = F.scaled_dot_product_attention( |
| query, |
| key, |
| value, |
| attn_mask=attn_mask, |
| is_causal=False, |
| dropout_p=dropout_p, |
| ).transpose(1, 2) |
| else: |
| attn_output = F.scaled_dot_product_attention( |
| query, |
| key, |
| value, |
| attn_mask=attn_mask, |
| is_causal=False, |
| dropout_p=dropout_p, |
| ).transpose(1, 2) |
|
|
| elif self.attn_implementation == "flash_attention_2": |
| # _flash_attention_forward downcasts fp32 inputs to target_dtype; leave it as |
| # None when the inputs are already in a flash-attention-compatible dtype. |
| target_dtype = None |
| if xq.dtype == torch.float32: |
| if torch.is_autocast_enabled(): |
| target_dtype = torch.get_autocast_gpu_dtype() |
| else: |
| target_dtype = self.wq.weight.dtype |
| attn_output = _flash_attention_forward( |
| xq, |
| xk, |
| xv, |
| attention_mask=attn_mask, |
| query_length=inputs_q.shape[1], |
| is_causal=False, |
| dropout=dropout_p, |
| softmax_scale=xq.shape[-1] ** -0.5, |
| use_top_left_mask=flash_attn_supports_top_left_mask(), |
| target_dtype=target_dtype, |
| implementation=self.attn_implementation, |
| ) |
| else: |
| raise ValueError( |
| f"Attention implementation {self.attn_implementation} not supported" |
| ) |
|
|
| attn_output = attn_output.to(og_dtype) |
| attn_output = self._merge_heads(attn_output) |
| attn_output = self.wo(attn_output) |
| attn_output = self.residual_dropout(attn_output) |
|
|
| return attn_output |
|
|
|
|
| class MolmoAct2VisionBlock(nn.Module): |
| def __init__( |
| self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None |
| ): |
| super().__init__() |
| self.attention = ViTMultiHeadDotProductAttention( |
| hidden_size=config.hidden_size, |
| num_heads=config.num_attention_heads, |
| num_key_value_heads=config.num_key_value_heads, |
| head_dim=config.head_dim, |
| float32_attention=config.float32_attention, |
| attention_dropout=config.attention_dropout, |
| residual_dropout=config.residual_dropout, |
| device=device, |
| attn_implementation=config._attn_implementation, |
| ) |
| self.feed_forward = ViTMLP( |
| config.hidden_size, |
| config.intermediate_size, |
| config.hidden_act, |
| device=device, |
| ) |
| self.attention_norm = nn.LayerNorm( |
| config.hidden_size, eps=config.layer_norm_eps, device=device |
| ) |
| self.ffn_norm = nn.LayerNorm( |
| config.hidden_size, eps=config.layer_norm_eps, device=device |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = x + self.attention(self.attention_norm(x)) |
| x = x + self.feed_forward(self.ffn_norm(x)) |
| return x |
|
|
|
|
| class MolmoAct2VisionBlockCollection(nn.Module): |
| def __init__( |
| self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None |
| ): |
| super().__init__() |
| self.config = config |
| self.resblocks = nn.ModuleList( |
| [ |
| MolmoAct2VisionBlock(config, device) |
| for _ in range(config.num_hidden_layers) |
| ] |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> list[torch.Tensor]: |
| hidden_states = [] |
| for r in self.resblocks: |
| x = r(x) |
| hidden_states.append(x) |
| return hidden_states |
|
|
|
|
| class MolmoAct2VisionTransformer(nn.Module): |
| def __init__( |
| self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None |
| ): |
| super().__init__() |
| self.config = config |
|
|
| |
| self.scale = config.hidden_size**-0.5 |
| self.num_prefix_tokens: int = 0 |
| self.positional_embedding = nn.Parameter( |
| torch.zeros(config.image_num_pos, config.hidden_size, device=device), |
| ) |
|
|
| image_patch_size = config.image_patch_size |
| self.patch_embedding = nn.Linear( |
| image_patch_size * image_patch_size * 3, |
| config.hidden_size, |
| bias=True, |
| device=device, |
| ) |
|
|
| self.transformer = MolmoAct2VisionBlockCollection(config, device) |
|
|
| def add_pos_emb(self, x: torch.Tensor, patch_num: Tuple[int, int]) -> torch.Tensor: |
| pos_emb = self.positional_embedding |
|
|
| pos_emb = pos_emb.reshape( |
| ( |
| int(math.sqrt(pos_emb.shape[0])), |
| int(math.sqrt(pos_emb.shape[0])), |
| pos_emb.shape[1], |
| ) |
| ) |
|
|
| (patch_num_0, patch_num_1) = patch_num |
|
|
| if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: |
| |
| |
| pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) |
| pos_emb = F.interpolate( |
| pos_emb, |
| size=(patch_num_0, patch_num_1), |
| mode="bicubic", |
| align_corners=False, |
| antialias=True, |
| ) |
| pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) |
|
|
| pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) |
| x = x + pos_emb[None, :, :].to(x.dtype) |
| return x |
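|
| # When the crop's patch grid differs from the pretraining grid, the learned |
| # positional embeddings are resized with bicubic interpolation. Illustrative |
| # sketch of that resize step in isolation (assumes a square 24x24 grid): |
| # |
| #     pe = torch.randn(576, 1024)                      # 24*24 positions, hidden 1024 |
| #     pe = pe.view(24, 24, 1024).permute(2, 0, 1)[None] |
| #     pe = F.interpolate(pe, size=(32, 32), mode="bicubic", align_corners=False) |
| #     pe = pe[0].permute(1, 2, 0).reshape(-1, 1024)    # (1024 positions, hidden) |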
|
|
| def forward( |
| self, x: torch.Tensor, patch_num: Optional[Tuple[int, int]] = None |
| ) -> list[torch.Tensor]: |
| """ |
| :param x: (batch_size, num_patch, n_pixels) |
| """ |
| if patch_num is None: |
| patch_num = self.config.image_num_patch |
|
|
| B, N, D = x.shape |
|
|
| x = self.patch_embedding(x) |
|
|
| |
| x = self.add_pos_emb(x, patch_num) |
|
|
| hidden_states = self.transformer(x) |
| return hidden_states |
|
|
|
|
| class ImageProjectorMLP(nn.Module): |
| def __init__( |
| self, |
| input_dim: int, |
| hidden_dim: int, |
| output_dim: int, |
| hidden_act: str, |
| device: Union[str, torch.device] = None, |
| ): |
| super().__init__() |
| self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device) |
| self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device) |
| self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device) |
| self.act = ACT2FN[hidden_act] |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.w2(self.act(self.w1(x)) * self.w3(x)) |
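|
| # SwiGLU-style gated projector: out = W2( act(W1 x) * W3 x ), |
| # with `act` chosen by `hidden_act`. |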
|
|
|
|
| class MolmoAct2VisionBackbone(nn.Module): |
| def __init__( |
| self, vit_config: MolmoAct2VitConfig, adapter_config: MolmoAct2AdapterConfig |
| ): |
| super().__init__() |
| self.vit_config = vit_config |
| self.adapter_config = adapter_config |
|
|
| self.vit_layers = [] |
| for layer in adapter_config.vit_layers: |
| if layer >= 0: |
| self.vit_layers.append(layer) |
| else: |
| self.vit_layers.append(layer + vit_config.num_hidden_layers) |
|
|
| last_layer_needed = max(self.vit_layers) + 1 |
| if last_layer_needed < vit_config.num_hidden_layers: |
| new_vit_config = deepcopy(vit_config) |
| new_vit_config.num_hidden_layers = last_layer_needed |
| self.image_vit = MolmoAct2VisionTransformer(new_vit_config) |
| else: |
| self.image_vit = MolmoAct2VisionTransformer(vit_config) |
|
|
| self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens |
|
|
| pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers) |
| self.image_pooling_2d = ViTMultiHeadDotProductAttention( |
| hidden_size=adapter_config.hidden_size, |
| num_heads=adapter_config.num_attention_heads, |
| num_key_value_heads=adapter_config.num_key_value_heads, |
| head_dim=adapter_config.head_dim, |
| input_dim=pool_dim, |
| float32_attention=adapter_config.float32_attention, |
| attention_dropout=adapter_config.attention_dropout, |
| residual_dropout=adapter_config.residual_dropout, |
| attn_implementation=adapter_config._attn_implementation, |
| ) |
| self.image_projector = ImageProjectorMLP( |
| adapter_config.hidden_size, |
| adapter_config.intermediate_size, |
| adapter_config.text_hidden_size, |
| adapter_config.hidden_act, |
| ) |
| self.image_feature_dropout = nn.Dropout(adapter_config.image_feature_dropout) |
|
|
| def encode_image(self, images: torch.Tensor) -> torch.Tensor: |
| """ |
| :param images: (batch_size, num_crops, num_patch, n_pixels) |
| """ |
| B, T, N, D = images.shape |
| images = images.view(B * T, N, D) |
| image_features = self.image_vit(images) |
|
|
| features = [] |
| for layer in self.vit_layers: |
| features.append(image_features[layer]) |
| image_features = torch.cat(features, dim=-1) |
|
|
| if self.num_prefix_tokens > 0: |
| image_features = image_features[:, 1:] |
| image_features = image_features.view(B, T, N, -1) |
| return image_features |
|
|
| @property |
| def dtype(self) -> torch.dtype: |
| return self.image_vit.patch_embedding.weight.dtype |
|
|
| @property |
| def device(self) -> torch.device: |
| return self.image_vit.patch_embedding.weight.device |
|
|
| def forward( |
| self, |
| images: torch.Tensor, |
| pooled_patches_idx: torch.Tensor, |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
|
| |
| batch_size, num_image = images.shape[:2] |
| images = images.to(device=self.device) |
| if images.dtype == torch.uint8: |
| images = images.to(dtype=torch.float32) / 255.0 |
| images = images * 2.0 - 1.0 |
| elif torch.is_floating_point(images): |
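| # Float inputs are expected in [-1, 1]; round-trip them through the same |
| # uint8-style quantization as the raw-pixel path above so both paths see |
| # identical pixel values. |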
| |
| |
| images = torch.round(((images.to(dtype=torch.float32) + 1.0) * 0.5) * 255.0) |
| images = torch.clamp(images, 0.0, 255.0) / 255.0 |
| images = images * 2.0 - 1.0 |
| images = images.to(dtype=self.dtype) |
| image_features = self.encode_image(images) |
|
|
| image_features = self.image_feature_dropout(image_features) |
| dim = image_features.shape[-1] |
| valid = pooled_patches_idx >= 0 |
| valid_token = torch.any(valid, -1) |
|
|
| |
| batch_idx = torch.arange( |
| pooled_patches_idx.shape[0], |
| dtype=torch.long, |
| device=pooled_patches_idx.device, |
| ) |
| batch_idx = torch.tile( |
| batch_idx.view(batch_size, 1, 1), |
| [1, pooled_patches_idx.shape[1], pooled_patches_idx.shape[2]], |
| ) |
|
|
| |
| to_pool = image_features.reshape(batch_size, -1, dim)[ |
| batch_idx, torch.clip(pooled_patches_idx, 0) |
| ] |
| to_pool = to_pool * valid.to(self.dtype)[:, :, :, None] |
| to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim]) |
| if self.adapter_config.pooling_attention_mask: |
| attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]]) |
| denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1) |
| denom = torch.where(denom == 0, 1, denom) |
| query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to( |
| to_pool.dtype |
| ) |
| else: |
| attn_mask = None |
| query = to_pool.mean(-2, keepdim=True) |
| pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask) |
| pooled_features = pooled_features.reshape( |
| [batch_size, -1, pooled_features.shape[-1]] |
| ) |
|
|
| |
| pooled_features = self.image_projector(pooled_features) |
| return pooled_features.view(-1, pooled_features.shape[-1])[ |
| valid_token.flatten() |
| ] |
|
|
|
|
| |
|
|
|
|
| def rotate_half(x): |
| """Rotates half the hidden dims of the input.""" |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
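|
| # Worked example (head_dim == 4): the two halves of the last dimension are |
| # swapped, with a sign flip on the second half: |
| # |
| #     rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])) |
| #     # -> tensor([-3., -4.,  1.,  2.]) |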
|
|
|
|
| |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| Args: |
| q (`torch.Tensor`): The query tensor. |
| k (`torch.Tensor`): The key tensor. |
| cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| sin (`torch.Tensor`): The sine part of the rotary embedding. |
| position_ids (`torch.Tensor`, *optional*): |
| Deprecated and unused. |
| unsqueeze_dim (`int`, *optional*, defaults to 1): |
| The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| Returns: |
| `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| """ |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| return q_embed, k_embed |
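|
| # Illustrative usage (shapes only, not executed here): with q/k laid out as |
| # (batch, heads, seq, head_dim) and cos/sin as (batch, seq, head_dim), |
| # |
| #     cos, sin = rotary_emb(hidden_states, position_ids) |
| #     q, k = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads |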
|
|
|
|
| class MolmoAct2RotaryEmbedding(nn.Module): |
| inv_freq: torch.Tensor |
|
|
| def __init__( |
| self, |
| config: MolmoAct2TextConfig, |
| device: Union[str, torch.device] = None, |
| rope_type: Optional[str] = None, |
| ): |
| super().__init__() |
| if rope_type is not None: |
| self.rope_type = rope_type |
| elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): |
| |
| self.rope_type = config.rope_scaling.get( |
| "rope_type", config.rope_scaling.get("type") |
| ) |
| else: |
| self.rope_type = "default" |
| self.max_seq_len_cached = config.max_position_embeddings |
| self.original_max_seq_len = config.max_position_embeddings |
|
|
| self.config = config |
| if self.rope_type == "default": |
| self.rope_init_fn = self._default_rope_init |
| else: |
| self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
| inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
| self.register_buffer("inv_freq", inv_freq, persistent=True) |
| self.original_inv_freq = self.inv_freq |
| self.register_buffer("_pos_sin_cache", torch.empty(0), persistent=False) |
| self.register_buffer("_pos_cos_cache", torch.empty(0), persistent=False) |
|
|
| @staticmethod |
| def _default_rope_init( |
| config: MolmoAct2TextConfig, device: Union[str, torch.device] = None, **_ |
| ) -> tuple[torch.Tensor, float]: |
| inv_freq = 1.0 / ( |
| config.rope_theta |
| ** ( |
| torch.arange(0, config.head_dim, 2, dtype=torch.float32, device=device) |
| / config.head_dim |
| ) |
| ) |
| return inv_freq, 1.0 |
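|
| # Standard RoPE frequency schedule: inv_freq[i] = 1 / rope_theta ** (2*i / head_dim) |
| # for i = 0 .. head_dim/2 - 1, so the first entry is always exactly 1.0 |
| # (which `_refresh_inv_freq_if_needed` relies on as a sanity check). |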
|
|
| def _target_cache_seq_len( |
| self, x: torch.Tensor, position_ids: Optional[torch.Tensor] |
| ) -> int: |
| if self.config.max_position_embeddings: |
| return int(self.config.max_position_embeddings) |
| if position_ids is not None: |
| return int(position_ids.max().item()) + 1 |
| return int(x.shape[-2]) |
|
|
| def _rope_cache_ready(self, device: torch.device, seq_len: int) -> bool: |
| return ( |
| self._pos_sin_cache.numel() > 0 |
| and self._pos_sin_cache.device == device |
| and self._pos_cos_cache.device == device |
| and self._pos_sin_cache.shape[-2] >= seq_len |
| and self._pos_cos_cache.shape[-2] >= seq_len |
| ) |
|
|
| def _refresh_inv_freq_if_needed(self, device: torch.device) -> None: |
| device = torch.device(device) |
| expected = int(self.config.head_dim) // 2 |
| needs_refresh = ( |
| self.inv_freq is None |
| or self._pos_sin_cache.numel() == 0 |
| or self.inv_freq.device.type == "meta" |
| or self.inv_freq.device != device |
| or self.inv_freq.numel() != expected |
| ) |
| if not needs_refresh: |
| inv_freq_cpu = self.inv_freq.detach() |
| needs_refresh = ( |
| not bool(torch.isfinite(inv_freq_cpu).all().item()) |
| or bool((inv_freq_cpu <= 0).any().item()) |
| or not bool( |
| torch.isclose(inv_freq_cpu[0].cpu(), torch.tensor(1.0)).item() |
| ) |
| ) |
| if needs_refresh: |
| inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
| self.register_buffer("inv_freq", inv_freq, persistent=True) |
| self.original_inv_freq = self.inv_freq |
| self._pos_sin_cache = torch.empty(0, device=device) |
| self._pos_cos_cache = torch.empty(0, device=device) |
|
|
| def _build_rope_cache(self, device: torch.device, seq_len: int) -> None: |
| device_type = device.type if device.type != "mps" else "cpu" |
| with torch.autocast(device_type=device_type, enabled=False): |
| seq = torch.arange(seq_len, device=device, dtype=torch.float) |
| freqs = torch.einsum( |
| "i,j->ij", seq, self.inv_freq.to(device=device, dtype=torch.float) |
| ) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| self._pos_sin_cache = emb.sin()[None, None, :, :] * self.attention_scaling |
| self._pos_cos_cache = emb.cos()[None, None, :, :] * self.attention_scaling |
|
|
| @torch.no_grad() |
| def prepare_rope_cache( |
| self, |
| *, |
| device: Union[str, torch.device], |
| max_seq_len: Optional[int] = None, |
| ) -> None: |
| if self.rope_type != "default": |
| return |
| device = torch.device(device) |
| seq_len = int(max_seq_len or self.config.max_position_embeddings or 0) |
| if seq_len <= 0: |
| raise ValueError( |
| "RoPE cache preparation requires a positive max sequence length." |
| ) |
| if self._rope_cache_ready(device, seq_len): |
| return |
| self._refresh_inv_freq_if_needed(device) |
| self._build_rope_cache(device, seq_len) |
|
|
| def _select_rope_cache( |
| self, |
| x: torch.Tensor, |
| position_ids: Optional[torch.Tensor], |
| seq_len: int, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| pos_sin = self._pos_sin_cache[:, :, :seq_len, :] |
| pos_cos = self._pos_cos_cache[:, :, :seq_len, :] |
| if position_ids is None: |
| sin = pos_sin[0, 0, : x.shape[-2], :] |
| cos = pos_cos[0, 0, : x.shape[-2], :] |
| else: |
| sin = pos_sin[0, 0][position_ids].view( |
| position_ids.shape + (pos_sin.shape[-1],) |
| ) |
| cos = pos_cos[0, 0][position_ids].view( |
| position_ids.shape + (pos_cos.shape[-1],) |
| ) |
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
| @torch.no_grad() |
| @dynamic_rope_update |
| def forward( |
| self, x, position_ids: torch.Tensor |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| seq_len = self._target_cache_seq_len(x, position_ids) |
| if not self._rope_cache_ready(x.device, seq_len): |
| self._refresh_inv_freq_if_needed(x.device) |
| self._build_rope_cache(x.device, seq_len) |
| return self._select_rope_cache(x, position_ids, seq_len) |
|
|
|
|
| class MolmoAct2RMSNorm(nn.Module): |
| def __init__( |
| self, |
| size: int, |
| eps: float = 1e-6, |
| device: Union[str, torch.device] = None, |
| ): |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(size, device=device)) |
| self.eps = eps |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| with torch.autocast(enabled=False, device_type=x.device.type): |
| og_dtype = x.dtype |
| x = x.to(torch.float32) |
| variance = x.pow(2).mean(-1, keepdim=True) |
| x = x * torch.rsqrt(variance + self.eps) |
| x = x.to(og_dtype) |
|
|
| return self.weight * x |
|
|
| def extra_repr(self): |
| return f"{tuple(self.weight.shape)}, eps={self.eps}" |
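|
| # RMSNorm, computed in float32 for numerical stability: |
| #     y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps) |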
|
|
|
|
| |
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| """ |
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| if n_rep == 1: |
| return hidden_states |
| hidden_states = hidden_states[:, :, None, :, :].expand( |
| batch, num_key_value_heads, n_rep, slen, head_dim |
| ) |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
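|
| # Shape sketch: with 2 KV heads and n_rep == 4, |
| # |
| #     kv = torch.randn(1, 2, 10, 64)   # (batch, kv_heads, seq, head_dim) |
| #     repeat_kv(kv, 4).shape           # torch.Size([1, 8, 10, 64]) |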
|
|
|
|
| def eager_attention_forward( |
| module: nn.Module, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| scaling: float, |
| dropout: float = 0.0, |
| **kwargs, |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| key_states = repeat_kv(key, module.num_key_value_groups) |
| value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
| attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
| if attention_mask is not None: |
| causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| attn_weights = attn_weights + causal_mask |
|
|
| attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( |
| query.dtype |
| ) |
| attn_weights = nn.functional.dropout( |
| attn_weights, p=dropout, training=module.training |
| ) |
| attn_output = torch.matmul(attn_weights, value_states) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
| return attn_output, attn_weights |
|
|
|
|
| class MolmoAct2Attention(nn.Module): |
| """Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
| def __init__(self, config: MolmoAct2TextConfig, layer_idx: int) -> None: |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.num_heads = config.num_attention_heads |
| self.num_key_value_heads = config.num_key_value_heads |
| self.num_key_value_groups = ( |
| config.num_attention_heads // config.num_key_value_heads |
| ) |
| self.head_dim = config.head_dim |
| self.scaling = self.head_dim**-0.5 |
| self.is_causal = True |
|
|
| self.fused_dims = ( |
| config.num_attention_heads * config.head_dim, |
| config.head_dim * config.num_key_value_heads, |
| config.head_dim * config.num_key_value_heads, |
| ) |
| self.att_proj = nn.Linear( |
| config.hidden_size, |
| sum(self.fused_dims), |
| bias=config.qkv_bias, |
| ) |
|
|
| |
| self.k_norm: Optional[MolmoAct2RMSNorm] = None |
| self.q_norm: Optional[MolmoAct2RMSNorm] = None |
| self.qk_norm_type: Optional[str] = None |
| if config.use_qk_norm: |
| k_norm_size = ( |
| config.head_dim |
| if config.qk_norm_type == "qwen3" |
| else config.num_key_value_heads * config.head_dim |
| ) |
| self.k_norm = MolmoAct2RMSNorm(k_norm_size, eps=config.layer_norm_eps) |
| q_norm_size = ( |
| config.head_dim |
| if config.qk_norm_type == "qwen3" |
| else config.num_attention_heads * config.head_dim |
| ) |
| self.q_norm = MolmoAct2RMSNorm(q_norm_size, eps=config.layer_norm_eps) |
| self.qk_norm_type = config.qk_norm_type |
|
|
| self.attention_dropout = config.attention_dropout |
|
|
| self.attn_out = nn.Linear( |
| config.head_dim * config.num_attention_heads, |
| config.hidden_size, |
| bias=False, |
| ) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor], |
| past_key_values: Optional[Cache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
|
|
| qkv = self.att_proj(hidden_states) |
| query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1) |
| value_states = value_states.view(hidden_shape) |
|
|
| |
| if ( |
| self.q_norm is not None |
| and self.k_norm is not None |
| and self.qk_norm_type != "qwen3" |
| ): |
| query_states = self.q_norm(query_states) |
| key_states = self.k_norm(key_states) |
|
|
| query_states = query_states.view(hidden_shape) |
| key_states = key_states.view(hidden_shape) |
| if ( |
| self.q_norm is not None |
| and self.k_norm is not None |
| and self.qk_norm_type == "qwen3" |
| ): |
| query_states = self.q_norm(query_states) |
| key_states = self.k_norm(key_states) |
| query_states = query_states.transpose(1, 2) |
| key_states = key_states.transpose(1, 2) |
| value_states = value_states.transpose(1, 2) |
|
|
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb( |
| query_states, key_states, cos, sin |
| ) |
|
|
| if past_key_values is not None: |
| |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| key_states, value_states = past_key_values.update( |
| key_states, value_states, self.layer_idx, cache_kwargs |
| ) |
|
|
| dropout_p = 0.0 if not self.training else self.attention_dropout |
| if self.config._attn_implementation == "sdpa" and ( |
| attention_mask is None or torch.is_tensor(attention_mask) |
| ): |
| key_states = repeat_kv(key_states, self.num_key_value_groups) |
| value_states = repeat_kv(value_states, self.num_key_value_groups) |
| attn_output = F.scaled_dot_product_attention( |
| query_states, |
| key_states, |
| value_states, |
| attn_mask=attention_mask, |
| dropout_p=dropout_p, |
| is_causal=attention_mask is None, |
| ) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
| attn_weights = None |
| else: |
| attention_interface: Callable = eager_attention_forward |
| if self.config._attn_implementation != "eager": |
| attention_interface = ALL_ATTENTION_FUNCTIONS[ |
| self.config._attn_implementation |
| ] |
|
|
| attn_output, attn_weights = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| dropout=dropout_p, |
| scaling=self.scaling, |
| **kwargs, |
| ) |
|
|
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = self.attn_out(attn_output) |
| return attn_output, attn_weights |
|
|
|
|
| class LanguageModelMLP(nn.Module): |
| def __init__( |
| self, |
| input_dim: int, |
| intermediate_size: int, |
| hidden_act: str, |
| device: Union[str, torch.device] = None, |
| ): |
| super().__init__() |
| self.ff_proj = nn.Linear( |
| input_dim, intermediate_size * 2, bias=False, device=device |
| ) |
| self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device) |
| self.act = ACT2FN[hidden_act] |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = self.ff_proj(x) |
| x, gate = x.chunk(2, dim=-1) |
| x = self.act(gate) * x |
| x = self.ff_out(x) |
| return x |
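|
| # Gated (SwiGLU-style) MLP: `ff_proj` emits 2 * intermediate_size features, |
| # chunked into (value, gate); the output is ff_out(act(gate) * value). Note |
| # that the gate is the *second* chunk of the fused projection. |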
|
|
|
|
| class MolmoAct2DecoderLayer(GradientCheckpointingLayer): |
| def __init__( |
| self, |
| config: MolmoAct2TextConfig, |
| layer_idx: Optional[int] = None, |
| device: Union[str, torch.device] = None, |
| ): |
| super().__init__() |
| self.config = config |
|
|
| self.self_attn = MolmoAct2Attention(config, layer_idx) |
| self.attn_norm = MolmoAct2RMSNorm( |
| config.hidden_size, eps=config.layer_norm_eps, device=device |
| ) |
| self.dropout = nn.Dropout(config.residual_dropout) |
| self.mlp = LanguageModelMLP( |
| config.hidden_size, |
| config.intermediate_size, |
| config.hidden_act, |
| device=device, |
| ) |
| self.ff_norm = MolmoAct2RMSNorm( |
| config.hidden_size, eps=config.layer_norm_eps, device=device |
| ) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> tuple[ |
| torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] |
| ]: |
|
|
| residual = hidden_states |
| hidden_states = self.attn_norm(hidden_states) |
|
|
| |
| hidden_states, self_attn_weights = self.self_attn( |
| hidden_states=hidden_states, |
| position_embeddings=position_embeddings, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| hidden_states = residual + self.dropout(hidden_states) |
|
|
| |
| residual = hidden_states |
| hidden_states = self.ff_norm(hidden_states) |
| hidden_states = self.mlp(hidden_states) |
|
|
| hidden_states = residual + self.dropout(hidden_states) |
|
|
| outputs = (hidden_states,) |
|
|
| if output_attentions: |
| outputs += (self_attn_weights,) |
|
|
| return outputs |
|
|
|
|
| class MolmoAct2PostNormDecoderLayer(MolmoAct2DecoderLayer): |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs, |
| ) -> tuple[ |
| torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] |
| ]: |
|
|
| residual = hidden_states |
|
|
| |
| hidden_states, self_attn_weights = self.self_attn( |
| hidden_states=hidden_states, |
| position_embeddings=position_embeddings, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| ) |
| hidden_states = self.attn_norm(hidden_states) |
|
|
| hidden_states = residual + self.dropout(hidden_states) |
|
|
| |
| residual = hidden_states |
| hidden_states = self.mlp(hidden_states) |
| hidden_states = self.ff_norm(hidden_states) |
|
|
| hidden_states = residual + self.dropout(hidden_states) |
|
|
| outputs = (hidden_states,) |
|
|
| if output_attentions: |
| outputs += (self_attn_weights,) |
|
|
| return outputs |
|
|
|
|
| class MolmoAct2Embedding(nn.Module): |
| def __init__( |
| self, |
| num_embeddings: int, |
| num_new_embeddings: int, |
| features: int, |
| device: Union[str, torch.device] = None, |
| ): |
| super().__init__() |
| self.embedding = nn.Parameter( |
| torch.zeros(num_embeddings, features, device=device), |
| ) |
| self.new_embedding = nn.Parameter( |
| torch.zeros(num_new_embeddings, features, device=device), |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0)) |
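|
| # The lookup table is the base vocabulary followed by the additional vocabulary, |
| # so ids in [0, num_embeddings) hit `embedding` and ids in |
| # [num_embeddings, num_embeddings + num_new_embeddings) hit `new_embedding`. |
| # Minimal sketch: |
| # |
| #     wte = MolmoAct2Embedding(num_embeddings=100, num_new_embeddings=4, features=8) |
| #     wte(torch.tensor([0, 99, 100, 103])).shape   # torch.Size([4, 8]) |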
|
|
|
|
| class MolmoAct2PreTrainedModel(PreTrainedModel): |
| config: MolmoAct2Config |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = [ |
| "MolmoAct2DecoderLayer", |
| "MolmoAct2PostNormDecoderLayer", |
| "MolmoAct2VisionBlock", |
| "ViTMultiHeadDotProductAttention", |
| ] |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn = True |
| _supports_sdpa = True |
|
|
| _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": MolmoAct2DecoderLayer, |
| "attentions": MolmoAct2Attention, |
| } |
|
|
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| if isinstance(module, (nn.Linear,)): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, MolmoAct2Embedding): |
| module.embedding.data.normal_(mean=0.0, std=std) |
| module.new_embedding.data.normal_(mean=0.0, std=std) |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
| elif isinstance(module, MolmoAct2RMSNorm): |
| module.weight.data.fill_(1.0) |
| elif isinstance(module, nn.LayerNorm): |
| module.weight.data.fill_(1.0) |
| if module.bias is not None: |
| module.bias.data.zero_() |
|
|
|
|
| class MolmoAct2TextModel(MolmoAct2PreTrainedModel): |
| config: MolmoAct2TextConfig |
| _no_split_modules = ["MolmoAct2DecoderLayer", "MolmoAct2PostNormDecoderLayer"] |
|
|
| def __init__(self, config: MolmoAct2TextConfig): |
| super().__init__(config) |
| if config.additional_vocab_size is not None: |
| self.wte = MolmoAct2Embedding( |
| config.vocab_size, |
| config.additional_vocab_size, |
| config.hidden_size, |
| ) |
| else: |
| self.wte = nn.Embedding(config.vocab_size, config.hidden_size) |
| self.emb_drop = nn.Dropout(config.embedding_dropout) |
| decoder_layer = ( |
| MolmoAct2PostNormDecoderLayer |
| if config.norm_after |
| else MolmoAct2DecoderLayer |
| ) |
| self.blocks = nn.ModuleList( |
| [ |
| decoder_layer(config, layer_idx) |
| for layer_idx in range(config.num_hidden_layers) |
| ] |
| ) |
| self.ln_f = MolmoAct2RMSNorm(config.hidden_size, eps=config.layer_norm_eps) |
| if config.rope_scaling_layers is not None: |
| self.rotary_embs = nn.ModuleDict( |
| { |
| "default": MolmoAct2RotaryEmbedding(config, rope_type="default"), |
| "scaling": MolmoAct2RotaryEmbedding(config), |
| } |
| ) |
| else: |
| self.rotary_emb = MolmoAct2RotaryEmbedding(config) |
| self.gradient_checkpointing = False |
|
|
| |
| self.post_init() |
|
|
| @torch.no_grad() |
| def prepare_rope_cache( |
| self, |
| *, |
| device: Union[str, torch.device], |
| max_seq_len: Optional[int] = None, |
| ) -> None: |
| if self.config.rope_scaling_layers is not None: |
| for rotary_emb in self.rotary_embs.values(): |
| rotary_emb.prepare_rope_cache(device=device, max_seq_len=max_seq_len) |
| return |
| self.rotary_emb.prepare_rope_cache(device=device, max_seq_len=max_seq_len) |
|
|
| def get_input_embeddings(self) -> torch.nn.Module: |
| return self.wte |
|
|
| def set_input_embeddings(self, value: torch.nn.Module) -> None: |
| self.wte = value |
|
|
| @can_return_tuple |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = ( |
| output_attentions |
| if output_attentions is not None |
| else self.config.output_attentions |
| ) |
| output_hidden_states = ( |
| output_hidden_states |
| if output_hidden_states is not None |
| else self.config.output_hidden_states |
| ) |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError( |
| "You must specify exactly one of input_ids or inputs_embeds" |
| ) |
|
|
| if self.gradient_checkpointing and self.training and use_cache: |
| logger.warning_once( |
| "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." |
| ) |
| use_cache = False |
|
|
| if inputs_embeds is None: |
| input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) |
| inputs_embeds = self.wte(input_ids) |
|
|
| |
| if use_cache and past_key_values is None and not torch.jit.is_tracing(): |
| past_key_values = DynamicCache(config=self.config) |
|
|
| if cache_position is None: |
| past_seen_tokens = ( |
| past_key_values.get_seq_length() if past_key_values is not None else 0 |
| ) |
| cache_position = torch.arange( |
| past_seen_tokens, |
| past_seen_tokens + inputs_embeds.shape[1], |
| device=inputs_embeds.device, |
| ) |
|
|
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| |
| if torch.is_tensor(attention_mask) and attention_mask.ndim == 4: |
| causal_mask_mapping = attention_mask |
| elif not isinstance(causal_mask_mapping := attention_mask, dict): |
| |
| mask_kwargs = { |
| "config": self.config, |
| "input_embeds": inputs_embeds, |
| "attention_mask": attention_mask, |
| "cache_position": cache_position, |
| "past_key_values": past_key_values, |
| "position_ids": position_ids, |
| } |
|
|
| |
| causal_mask_mapping = create_causal_mask(**mask_kwargs) |
|
|
| hidden_states = inputs_embeds |
|
|
| |
| if self.config.rope_scaling_layers is not None: |
| position_embeddings_mapping = { |
| "default": self.rotary_embs["default"](hidden_states, position_ids), |
| "scaling": self.rotary_embs["scaling"](hidden_states, position_ids), |
| } |
| else: |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
| |
| all_hidden_states = () if output_hidden_states else None |
| all_self_attns = () if output_attentions else None |
|
|
| for layer_idx, decoder_block in enumerate( |
| self.blocks[: self.config.num_hidden_layers] |
| ): |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| if self.config.rope_scaling_layers is not None: |
| position_embeddings_i = ( |
| position_embeddings_mapping["scaling"] |
| if layer_idx in self.config.rope_scaling_layers |
| else position_embeddings_mapping["default"] |
| ) |
| else: |
| position_embeddings_i = position_embeddings |
|
|
| layer_outputs = decoder_block( |
| hidden_states, |
| attention_mask=causal_mask_mapping, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings_i, |
| **kwargs, |
| ) |
|
|
| hidden_states = layer_outputs[0] |
|
|
| if output_attentions: |
| all_self_attns += (layer_outputs[1],) |
|
|
| hidden_states = self.ln_f(hidden_states) |
|
|
| |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values, |
| hidden_states=all_hidden_states, |
| attentions=all_self_attns, |
| ) |
|
|
|
|
| |
| def token_type_ids_mask_function( |
| token_type_ids: Optional[torch.Tensor] = None, |
| ) -> Optional[Callable]: |
| """ |
| Returns a mask function, in the callable format expected by the transformers |
| masking utilities, that marks query/key pairs where both tokens are image |
| tokens (token_type_id == 1) so they can attend to each other bidirectionally. |
| """ |
| |
| if token_type_ids is None: |
| return None |
|
|
| def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: |
| |
| |
| |
| safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) |
| token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx] |
| token_type_ids_at_kv_idx = torch.where( |
| kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0 |
| ) |
|
|
| is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & ( |
| token_type_ids_at_kv_idx == 1 |
| ) |
|
|
| |
| # Image tokens attend to all other image tokens (bidirectional within the block). |
| return is_image_block |
|
|
| return inner_mask |
|
|
|
|
| class MolmoAct2Model(MolmoAct2PreTrainedModel): |
| base_model_prefix = "" |
| _checkpoint_conversion_mapping = {} |
| |
| accepts_loss_kwargs = False |
| config: MolmoAct2Config |
|
|
| def __init__(self, config: MolmoAct2Config): |
| super().__init__(config) |
| self.transformer: MolmoAct2TextModel = MolmoAct2TextModel(config.text_config) |
| self.vision_backbone: Optional[MolmoAct2VisionBackbone] = None |
| if config.vit_config is not None and config.adapter_config is not None: |
| self.vision_backbone = MolmoAct2VisionBackbone( |
| config.vit_config, config.adapter_config |
| ) |
| llm_kv_dim = ( |
| config.text_config.num_key_value_heads * config.text_config.head_dim |
| ) |
| if config.add_action_expert: |
| self.action_expert = ActionExpert( |
| config.action_expert_config, |
| llm_dim=config.hidden_size, |
| llm_kv_dim=llm_kv_dim, |
| llm_num_layers=config.num_hidden_layers, |
| ) |
| else: |
| self.action_expert = None |
| if config.add_action_expert and config.action_expert_depth_gate: |
| if config.action_expert_depth_gate_per_layer: |
| self.action_expert_depth_gate = nn.ModuleList( |
| nn.Linear(llm_kv_dim, 1) |
| for _ in range(config.action_expert_config.num_layers) |
| ) |
| else: |
| self.action_expert_depth_gate = nn.Linear(llm_kv_dim, 1) |
| self.reset_action_expert_depth_gate_parameters() |
| else: |
| self.action_expert_depth_gate = None |
| self._depth_gate_token_ids = self._resolve_depth_gate_token_ids() |
| self.action_cuda_graph_manager: Optional[ActionCudaGraphManager] = None |
|
|
| |
| self.post_init() |
|
|
| def get_input_embeddings(self) -> torch.nn.Module: |
| return self.transformer.wte |
|
|
| def set_input_embeddings(self, value: torch.nn.Module) -> None: |
| self.transformer.wte = value |
|
|
| def set_decoder(self, decoder): |
| self.transformer = decoder |
|
|
| def get_decoder(self): |
| return self.transformer |
|
|
| @property |
| def device(self) -> torch.device: |
| return self.transformer.ln_f.weight.device |
|
|
| def reset_action_expert_depth_gate_parameters(self) -> None: |
| if self.action_expert_depth_gate is None: |
| return |
| gates = ( |
| self.action_expert_depth_gate |
| if isinstance(self.action_expert_depth_gate, nn.ModuleList) |
| else [self.action_expert_depth_gate] |
| ) |
| for gate in gates: |
| nn.init.zeros_(gate.weight) |
| nn.init.constant_( |
| gate.bias, float(self.config.action_expert_depth_gate_init_bias) |
| ) |
|
|
| def _resolve_depth_gate_token_ids(self) -> Tuple[int, ...]: |
| if not self.config.action_expert_depth_gate: |
| return () |
| token_ids = [] |
| for token_id in ( |
| self.config.depth_output_token_id, |
| self.config.depth_start_token_id, |
| self.config.depth_end_token_id, |
| ): |
| if token_id is not None: |
| token_ids.append(int(token_id)) |
| if ( |
| self.config.depth_token_start_id is not None |
| and int(self.config.num_depth_tokens or 0) > 0 |
| ): |
| start = int(self.config.depth_token_start_id) |
| token_ids.extend(range(start, start + int(self.config.num_depth_tokens))) |
| return tuple(dict.fromkeys(token_ids)) |
|
|
| def _require_action_expert(self) -> ActionExpert: |
| if self.action_expert is None: |
| raise RuntimeError( |
| "This MolmoAct2 checkpoint does not include an action expert." |
| ) |
| return self.action_expert |
|
|
| def _cache_to_sequence(self, cache: torch.Tensor) -> torch.Tensor: |
| if cache.dim() != 4: |
| raise ValueError( |
| f"Expected KV cache tensor with 4 dims, got shape {tuple(cache.shape)}" |
| ) |
| head_candidates = { |
| self.config.text_config.num_key_value_heads, |
| self.config.text_config.num_attention_heads, |
| } |
| if cache.shape[1] in head_candidates: |
| bsz, n_heads, seq_len, head_dim = cache.shape |
| return cache.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * head_dim) |
| if cache.shape[2] in head_candidates: |
| bsz, seq_len, n_heads, head_dim = cache.shape |
| return cache.reshape(bsz, seq_len, n_heads * head_dim) |
| if cache.shape[1] <= cache.shape[2]: |
| bsz, n_heads, seq_len, head_dim = cache.shape |
| return cache.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * head_dim) |
| bsz, seq_len, n_heads, head_dim = cache.shape |
| return cache.reshape(bsz, seq_len, n_heads * head_dim) |
|
|
| def _extract_kv_states( |
| self, past_key_values: Cache |
| ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]: |
| if past_key_values is None: |
| raise RuntimeError( |
| "Action generation requires past_key_values from the VLM forward pass." |
| ) |
| seq_len = _cache_seq_len_int(past_key_values) |
| kv_states = [] |
| for key, value in _iter_cache_key_values(past_key_values): |
| if key is None or value is None: |
| continue |
| if key.shape[-2] > seq_len: |
| key = key[..., :seq_len, :] |
| value = value[..., :seq_len, :] |
| kv_states.append( |
| (self._cache_to_sequence(key), self._cache_to_sequence(value)) |
| ) |
| if len(kv_states) != self.config.action_expert_config.num_layers: |
| raise RuntimeError( |
| f"Expected {self.config.action_expert_config.num_layers} KV layers, got {len(kv_states)}." |
| ) |
| return kv_states |
|
|
| @staticmethod |
| def _mask_discrete_output_span( |
| row_ids: torch.Tensor, |
| row_mask: torch.Tensor, |
| start_id: Optional[int], |
| end_id: Optional[int], |
| ) -> None: |
| if start_id is None or end_id is None: |
| return |
| start_positions = ( |
| (row_ids == start_id).nonzero(as_tuple=False).flatten().tolist() |
| ) |
| if not start_positions: |
| return |
| end_positions = (row_ids == end_id).nonzero(as_tuple=False).flatten().tolist() |
| end_ptr = 0 |
| for start_pos in start_positions: |
| while end_ptr < len(end_positions) and end_positions[end_ptr] < start_pos: |
| end_ptr += 1 |
| if end_ptr >= len(end_positions): |
| row_mask[start_pos:] = False |
| break |
| end_pos = end_positions[end_ptr] |
| row_mask[start_pos : end_pos + 1] = False |
| end_ptr += 1 |
|
|
| def _get_encoder_attention_mask( |
| self, |
| input_ids: Optional[torch.Tensor], |
| attention_mask: Optional[torch.Tensor], |
| ) -> Optional[torch.Tensor]: |
| if attention_mask is not None: |
| mask = attention_mask.to(dtype=torch.bool).clone() |
| elif input_ids is not None: |
| mask = input_ids != -1 |
| else: |
| return None |
| if self.config.action_format != "both" or input_ids is None: |
| return mask |
| eos_id = getattr(self.config, "eos_token_id", None) |
| if eos_id is not None: |
| mask &= input_ids != int(eos_id) |
| for batch_idx in range(input_ids.shape[0]): |
| self._mask_discrete_output_span( |
| input_ids[batch_idx], |
| mask[batch_idx], |
| self.config.action_start_token_id, |
| self.config.action_end_token_id, |
| ) |
| return mask |
|
|
| def _get_depth_token_mask( |
| self, |
| input_ids: Optional[torch.Tensor], |
| encoder_attention_mask: Optional[torch.Tensor], |
| ) -> Optional[torch.Tensor]: |
| if ( |
| not self.config.action_expert_depth_gate |
| or input_ids is None |
| or not self._depth_gate_token_ids |
| ): |
| return None |
| depth_token_ids = torch.as_tensor( |
| self._depth_gate_token_ids, |
| device=input_ids.device, |
| dtype=input_ids.dtype, |
| ) |
| depth_mask = (input_ids.unsqueeze(-1) == depth_token_ids).any(dim=-1) |
| if encoder_attention_mask is not None: |
| depth_mask = depth_mask & encoder_attention_mask.to( |
| device=input_ids.device, dtype=torch.bool |
| ) |
| return depth_mask |
|
|
| @staticmethod |
| def _depth_gate_from_source( |
| gate_head: nn.Linear, |
| *, |
| source: torch.Tensor, |
| depth_mask: torch.Tensor, |
| encoder_attention_mask: Optional[torch.Tensor], |
| ) -> torch.Tensor: |
| if source.ndim == 4: |
| source = source.reshape(source.shape[0], source.shape[1], -1) |
| if source.ndim != 3: |
| raise ValueError( |
| f"Depth gate expected a 3D sequence tensor, got {tuple(source.shape)}." |
| ) |
| if encoder_attention_mask is not None: |
| valid_mask = encoder_attention_mask.to( |
| device=source.device, dtype=torch.bool |
| ) |
| else: |
| valid_mask = torch.ones( |
| depth_mask.shape, device=source.device, dtype=torch.bool |
| ) |
| depth_mask = depth_mask.to(device=source.device, dtype=torch.bool) |
| pool_mask = valid_mask & ~depth_mask |
| has_pool = pool_mask.any(dim=-1, keepdim=True) |
| pool_mask = torch.where(has_pool, pool_mask, valid_mask) |
| weights = pool_mask.to(dtype=source.dtype).unsqueeze(-1) |
| pooled = (source * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0) |
| gate_logits = gate_head(pooled.to(dtype=gate_head.weight.dtype)) |
| return torch.sigmoid(gate_logits).to(dtype=source.dtype) |
|
|
| def _depth_gate_from_condition( |
| self, |
| *, |
| input_ids: Optional[torch.Tensor], |
| encoder_attention_mask: Optional[torch.Tensor], |
| layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]], |
| ) -> Tuple[ |
| Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], Optional[torch.Tensor] |
| ]: |
| gate_head = self.action_expert_depth_gate |
| if gate_head is None: |
| return None, None |
| depth_mask = self._get_depth_token_mask(input_ids, encoder_attention_mask) |
| if depth_mask is None or layer_kv_states is None: |
| return None, depth_mask |
| sources = [value for _, value in layer_kv_states] |
| if isinstance(gate_head, nn.ModuleList): |
| if len(gate_head) != len(sources): |
| raise ValueError( |
| f"Depth gate layer count mismatch: gates={len(gate_head)}, sources={len(sources)}." |
| ) |
| gates = [ |
| self._depth_gate_from_source( |
| gate, |
| source=source, |
| depth_mask=depth_mask, |
| encoder_attention_mask=encoder_attention_mask, |
| ) |
| for gate, source in zip(gate_head, sources) |
| ] |
| return gates, depth_mask |
| gate = self._depth_gate_from_source( |
| gate_head, |
| source=sources[-1], |
| depth_mask=depth_mask, |
| encoder_attention_mask=encoder_attention_mask, |
| ) |
| return gate, depth_mask |
|
|
| @staticmethod |
| def _depth_gate_for_layer( |
| gate: Union[torch.Tensor, Sequence[torch.Tensor]], |
| layer_idx: int, |
| *, |
| num_layers: int, |
| ) -> torch.Tensor: |
| if isinstance(gate, torch.Tensor): |
| return gate |
| if len(gate) != num_layers: |
| raise ValueError( |
| f"Depth gate layer count mismatch: gates={len(gate)}, layers={num_layers}." |
| ) |
| return gate[layer_idx] |
|
|
| def _apply_depth_gate_to_layer_kv_states( |
| self, |
| layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]], |
| depth_mask: Optional[torch.Tensor], |
| gate: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], |
| ) -> Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]]: |
| if layer_kv_states is None or depth_mask is None or gate is None: |
| return layer_kv_states |
| gated_kv = [] |
| for layer_idx, (key, value) in enumerate(layer_kv_states): |
| layer_gate = self._depth_gate_for_layer( |
| gate, layer_idx, num_layers=len(layer_kv_states) |
| ) |
| mask = depth_mask.to(device=key.device, dtype=torch.bool) |
| view_shape = [mask.shape[0], mask.shape[1]] + [1] * (key.ndim - 2) |
| scale = torch.ones(view_shape, device=key.device, dtype=key.dtype) |
| gate_view = layer_gate.to(device=key.device, dtype=key.dtype).view( |
| layer_gate.shape[0], |
| *([1] * (key.ndim - 1)), |
| ) |
| scale = torch.where(mask.view(view_shape), gate_view, scale) |
| gated_kv.append((key * scale, value * scale)) |
| return gated_kv |
|
|
| @staticmethod |
| def _action_dim_valid_mask( |
| target: torch.Tensor, |
| action_dim_is_pad: Optional[torch.Tensor], |
| ) -> Optional[torch.Tensor]: |
| if action_dim_is_pad is None: |
| return None |
| mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool) |
| if mask.ndim == 1: |
| mask = mask.unsqueeze(0) |
| if mask.shape[-1] != target.shape[-1]: |
| raise ValueError( |
| f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}." |
| ) |
| if mask.shape[0] == 1 and target.shape[0] != 1: |
| mask = mask.expand(target.shape[0], -1) |
| if mask.shape[0] != target.shape[0]: |
| raise ValueError( |
| f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}." |
| ) |
| while mask.ndim < target.ndim: |
| mask = mask.unsqueeze(1) |
| return mask |
|
|
| @classmethod |
| def _mask_action_dim_tensor( |
| cls, |
| tensor: torch.Tensor, |
| *, |
| action_dim_is_pad: Optional[torch.Tensor], |
| enabled: bool, |
| ) -> torch.Tensor: |
| if not enabled: |
| return tensor |
| valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad) |
| if valid_mask is None: |
| return tensor |
| return tensor.masked_fill(~valid_mask, 0) |
|
|
| def _run_action_flow_loop( |
| self, inputs: _ActionFlowInputs, steps: int |
| ) -> torch.Tensor: |
| action_expert = self._require_action_expert() |
| dt = 1.0 / steps |
| trajectory = inputs.trajectory |
| action_dim_is_pad = inputs.action_dim_is_pad |
| mask_enabled = self.config.mask_action_dim_padding |
| for idx in range(steps): |
| velocity = action_expert.forward_with_context( |
| trajectory, |
| inputs.modulations[idx].conditioning, |
| context=inputs.context, |
| modulation=inputs.modulations[idx], |
| ) |
| velocity = self._mask_action_dim_tensor( |
| velocity, |
| action_dim_is_pad=action_dim_is_pad, |
| enabled=mask_enabled, |
| ) |
| trajectory = trajectory + dt * velocity |
| trajectory = self._mask_action_dim_tensor( |
| trajectory, |
| action_dim_is_pad=action_dim_is_pad, |
| enabled=mask_enabled, |
| ) |
| return trajectory |
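|
| # Fixed-step Euler integration of the flow-matching ODE dx/dt = v_theta(x_t, t) |
| # with timesteps t = 0, 1/steps, ..., (steps-1)/steps, starting from Gaussian |
| # noise and integrating toward the action trajectory; padded action dimensions |
| # are re-zeroed after every step when masking is enabled. |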
|
|
| @torch.no_grad() |
| def generate_actions_from_inputs( |
| self, |
| *, |
| input_ids: torch.LongTensor, |
| pixel_values: Optional[torch.Tensor] = None, |
| image_token_pooling: Optional[torch.Tensor] = None, |
| image_grids: Optional[torch.Tensor] = None, |
| image_num_crops: Optional[torch.Tensor] = None, |
| pixel_values_videos: Optional[torch.Tensor] = None, |
| video_token_pooling: Optional[torch.Tensor] = None, |
| video_grids: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| token_type_ids: Optional[torch.LongTensor] = None, |
| states: Optional[torch.Tensor] = None, |
| action_dim_is_pad: Optional[torch.Tensor] = None, |
| num_steps: Optional[int] = None, |
| generator: Optional[torch.Generator] = None, |
| encoder_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None, |
| encoder_attention_mask: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| action_expert = self._require_action_expert() |
| if encoder_kv_states is None: |
| outputs = self( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| image_token_pooling=image_token_pooling, |
| image_grids=image_grids, |
| image_num_crops=image_num_crops, |
| pixel_values_videos=pixel_values_videos, |
| video_token_pooling=video_token_pooling, |
| video_grids=video_grids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| use_cache=True, |
| ) |
| encoder_kv_states = self._extract_kv_states(outputs.past_key_values) |
| encoder_attention_mask = self._get_encoder_attention_mask( |
| input_ids, attention_mask |
| ) |
| elif encoder_attention_mask is None: |
| encoder_attention_mask = self._get_encoder_attention_mask( |
| input_ids, attention_mask |
| ) |
|
|
| depth_gate, depth_mask = self._depth_gate_from_condition( |
| input_ids=input_ids, |
| encoder_attention_mask=encoder_attention_mask, |
| layer_kv_states=encoder_kv_states, |
| ) |
| encoder_kv_states = self._apply_depth_gate_to_layer_kv_states( |
| encoder_kv_states, |
| depth_mask, |
| depth_gate, |
| ) |
| steps = int(num_steps or self.config.flow_matching_num_steps) |
| if steps <= 0: |
| raise ValueError(f"num_steps must be >= 1, got {steps}.") |
| source_tensor = encoder_kv_states[0][0] |
| batch_size = source_tensor.shape[0] |
| device = source_tensor.device |
| trajectory = torch.randn( |
| (batch_size, self.config.action_horizon, self.config.max_action_dim), |
| device=device, |
| dtype=torch.float32, |
| generator=generator, |
| ) |
| trajectory = self._mask_action_dim_tensor( |
| trajectory, |
| action_dim_is_pad=action_dim_is_pad, |
| enabled=self.config.mask_action_dim_padding, |
| ) |
| action_context = action_expert.prepare_context( |
| encoder_kv_states=encoder_kv_states, |
| encoder_attention_mask=encoder_attention_mask, |
| state_embeddings=states, |
| batch_size=batch_size, |
| seq_len=trajectory.shape[1], |
| device=device, |
| dtype=trajectory.dtype, |
| ) |
| flow_timesteps = [ |
| torch.full( |
| (batch_size,), idx / steps, device=device, dtype=trajectory.dtype |
| ) |
| for idx in range(steps) |
| ] |
| modulation_cache = action_expert.get_or_prepare_modulation_cache( |
| flow_timesteps, |
| cache_key=(steps, batch_size, device, trajectory.dtype), |
| ) |
| flow_inputs = _ActionFlowInputs( |
| trajectory=trajectory, |
| context=action_context, |
| modulations=modulation_cache, |
| action_dim_is_pad=action_dim_is_pad, |
| ) |
| action_cuda_graph_manager = self.action_cuda_graph_manager |
| if ( |
| action_cuda_graph_manager is not None |
| and action_cuda_graph_manager.can_use_action_flow(flow_inputs) |
| ): |
| trajectory = action_cuda_graph_manager.run_action_flow( |
| flow_inputs, steps, self._run_action_flow_loop |
| ) |
| else: |
| trajectory = self._run_action_flow_loop(flow_inputs, steps) |
| return trajectory |
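|
| # Illustrative call (hypothetical `batch` dict produced by the MolmoAct2 |
| # processor; a sketch, not a guaranteed processor API): |
| # |
| #     actions = model.generate_actions_from_inputs( |
| #         input_ids=batch["input_ids"], |
| #         pixel_values=batch.get("pixel_values"), |
| #         image_token_pooling=batch.get("image_token_pooling"), |
| #         image_grids=batch.get("image_grids"), |
| #         image_num_crops=batch.get("image_num_crops"), |
| #         attention_mask=batch.get("attention_mask"), |
| #         num_steps=10, |
| #     ) |
| #     # actions: (batch, action_horizon, max_action_dim), float32 |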
|
|
| def build_batched_images( |
| self, |
| input_ids: torch.LongTensor, |
| pixel_values: torch.Tensor, |
| image_token_pooling: torch.Tensor, |
| image_grids: torch.Tensor, |
| image_num_crops: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| |
| raw_counts = (input_ids == self.config.image_end_token_id).sum(1) |
| total_images = int(image_grids.size(0)) |
| total_end_tokens = int(raw_counts.sum().item()) |
| if total_images <= 0: |
| counts = raw_counts.new_zeros(raw_counts.shape) |
| elif total_end_tokens == total_images: |
| counts = raw_counts |
| elif total_end_tokens == 2 * total_images: |
| counts = raw_counts // 2 |
| else: |
| raise ValueError( |
| "Could not infer image counts from image end tokens: " |
| f"end_tokens={total_end_tokens}, image_grids={total_images}." |
| ) |
| N = counts.size(0) |
| device = input_ids.device |
|
|
| |
| num_images = total_images |
|
|
| |
| assert image_grids.size(0) == num_images, ( |
| f"Expected {num_images} image grids, but got {image_grids.size(0)}" |
| ) |
| assert image_num_crops.size(0) == num_images, ( |
| f"Expected {num_images} image num crops, but got {image_num_crops.size(0)}" |
| ) |
|
|
| |
| with torch.no_grad(): |
| first_prod = image_grids[:, :2].prod(dim=1) |
| second_prod = image_grids[:, 2:].prod(dim=1) |
| num_pooled_patches_per_image = (first_prod + second_prod).to( |
| image_num_crops.dtype |
| ) |
|
|
| |
| n_crops, n_patches, pixels_per_patch = pixel_values.shape |
|
|
| |
| |
| example_ids_for_image = torch.arange(N, device=device).repeat_interleave( |
| counts |
| ) |
| assert example_ids_for_image.numel() == num_images |
|
|
| |
| crops_per_example = torch.zeros( |
| N, dtype=image_num_crops.dtype, device=image_num_crops.device |
| ) |
| crops_per_example.index_add_(0, example_ids_for_image, image_num_crops) |
|
|
| |
| patches_per_image = image_num_crops * n_patches |
|
|
| |
| counts_list = counts.tolist() |
| index_offset_per_example_list = [] |
| offset_img = 0 |
| for c in counts_list: |
| per_img_patches = patches_per_image[offset_img : offset_img + c] |
| |
| index_offset = [0] + per_img_patches.cumsum(0).tolist()[:-1] |
| index_offset_per_example_list.append(index_offset) |
| offset_img += c |
|
|
| |
| num_pooled_patches_per_example = torch.zeros( |
| N, |
| dtype=num_pooled_patches_per_image.dtype, |
| device=num_pooled_patches_per_image.device, |
| ) |
| num_pooled_patches_per_example.index_add_( |
| 0, example_ids_for_image, num_pooled_patches_per_image |
| ) |
|
|
| |
| total_crops = int(crops_per_example.sum().item()) |
| assert total_crops == n_crops, ( |
| f"Expected {total_crops} crops, but got {n_crops}" |
| ) |
|
|
| total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item()) |
| assert total_num_pooled_patches == image_token_pooling.size(0), ( |
| f"Expected {total_num_pooled_patches} pooled patches, but got {image_token_pooling.size(0)}" |
| ) |
|
|
| |
| M = int(crops_per_example.max().item()) |
| images = torch.full( |
| (N, M, n_patches, pixels_per_patch), |
| fill_value=-1, |
| dtype=pixel_values.dtype, |
| device=pixel_values.device, |
| ) |
|
|
| |
| offset_crop = 0 |
| for i in range(N): |
| num = int(crops_per_example[i].item()) |
| cur = pixel_values[ |
| offset_crop : offset_crop + num |
| ] |
| images[i, :num] = cur |
| offset_crop += num |
|
|
| |
| assert offset_crop == n_crops |
|
|
| |
| P = int(num_pooled_patches_per_example.max().item()) |
| _, dim = image_token_pooling.shape |
| new_token_pooling = torch.full( |
| (N, P, dim), |
| fill_value=-1, |
| dtype=image_token_pooling.dtype, |
| device=image_token_pooling.device, |
| ) |
|
|
| |
| patch_offset = 0 |
| img_offset = 0 |
|
|
| for i, c in enumerate(counts_list): |
| num_patches = int(num_pooled_patches_per_example[i].item()) |
|
|
| |
| cur = image_token_pooling[ |
| patch_offset : patch_offset + num_patches |
| ].clone() |
|
|
| index_offset_per_example = index_offset_per_example_list[i] |
| per_img_pooled = num_pooled_patches_per_image[ |
| img_offset : img_offset + c |
| ] |
|
|
| assert len(index_offset_per_example) == per_img_pooled.numel() |
|
|
| |
| offset = 0 |
| for j in range(c): |
| index_offset = int(index_offset_per_example[j]) |
| n = int(per_img_pooled[j].item()) |
| cur_slice = cur[offset : offset + n] |
|
|
| |
| cur[offset : offset + n] = torch.where( |
| cur_slice >= 0, |
| cur_slice + index_offset, |
| cur_slice, |
| ) |
| offset += n |
|
|
| new_token_pooling[i, :num_patches] = cur |
|
|
| patch_offset += num_patches |
| img_offset += c |
|
|
| |
| assert patch_offset == total_num_pooled_patches |
| assert img_offset == num_images |
|
|
| return images, new_token_pooling |
|
|
| def build_batched_videos( |
| self, |
| input_ids: torch.LongTensor, |
| pixel_values_videos: torch.Tensor, |
| video_token_pooling: torch.Tensor, |
| video_grids: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
| |
| if self.config.use_frame_special_tokens: |
| end_token_id = self.config.frame_end_token_id |
| else: |
| end_token_id = self.config.image_end_token_id |
| counts = (input_ids == end_token_id).any(dim=1).long() |
| N = counts.size(0) |
| device = input_ids.device |
|
|
| |
| num_videos = int(counts.sum().item()) |
|
|
| |
| assert video_grids.size(0) == num_videos, ( |
| f"Expected {num_videos} videos, but got {video_grids.size(0)}" |
| ) |
|
|
| video_num_frames = video_grids[:, 0] |
| num_pooled_patches_per_video = video_grids.prod(dim=1) |
|
|
| |
| n_frames, n_patches, pixels_per_patch = pixel_values_videos.shape |
|
|
| |
| |
| example_ids_for_video = torch.arange(N, device=device).repeat_interleave( |
| counts |
| ) |
| assert example_ids_for_video.numel() == num_videos |
|
|
| |
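| # Accumulate frame counts per example. |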
| frames_per_example = torch.zeros( |
| N, |
| dtype=video_num_frames.dtype, |
| device=device, |
| ) |
| frames_per_example.index_add_(0, example_ids_for_video, video_num_frames) |
|
|
| |
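| # Accumulate pooled-patch counts per example. |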
| num_pooled_patches_per_example = torch.zeros( |
| N, |
| dtype=num_pooled_patches_per_video.dtype, |
| device=num_pooled_patches_per_video.device, |
| ) |
| num_pooled_patches_per_example.index_add_( |
| 0, |
| example_ids_for_video, |
| num_pooled_patches_per_video, |
| ) |
|
|
| |
| total_frames = int(frames_per_example.sum().item()) |
| assert total_frames == n_frames, ( |
| f"Expected {total_frames} frames, but got {n_frames}" |
| ) |
|
|
| total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item()) |
| assert total_num_pooled_patches == video_token_pooling.size(0), ( |
| f"Expected {total_num_pooled_patches} pooled patches, but got {video_token_pooling.size(0)}" |
| ) |
|
|
| |
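| # Allocate a padded (N, max_frames, n_patches, pixels_per_patch) tensor, using -1 as padding. |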
| M = int(frames_per_example.max().item()) |
| videos = torch.full( |
| (N, M, n_patches, pixels_per_patch), |
| fill_value=-1, |
| dtype=pixel_values_videos.dtype, |
| device=device, |
| ) |
|
|
| |
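| # Copy each example's frames from the flat tensor into its padded slot. |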
| offset_frame = 0 |
| for i in range(N): |
| num = int(frames_per_example[i].item()) |
| cur = pixel_values_videos[ |
| offset_frame : offset_frame + num |
| ] |
| videos[i, :num] = cur |
| offset_frame += num |
|
|
| |
| assert offset_frame == n_frames |
|
|
| |
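| # Pad the pooled-patch index rows to the max per-example count, using -1 padding. |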
| P = int(num_pooled_patches_per_example.max().item()) |
| _, dim = video_token_pooling.shape |
| new_token_pooling = torch.full( |
| (N, P, dim), |
| fill_value=-1, |
| dtype=video_token_pooling.dtype, |
| device=video_token_pooling.device, |
| ) |
|
|
| |
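| # Copy each example's pooling indices directly from the flat tensor into its padded slot. |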
| patch_offset = 0 |
| for i in range(N): |
| num_patches = int(num_pooled_patches_per_example[i].item()) |
| cur = video_token_pooling[ |
| patch_offset : patch_offset + num_patches |
| ] |
| new_token_pooling[i, :num_patches] = cur |
| patch_offset += num_patches |
|
|
| |
| assert patch_offset == total_num_pooled_patches |
|
|
| return videos, new_token_pooling |
|
|
| def merge_visual_inputs( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| image_token_pooling: Optional[torch.Tensor] = None, |
| image_grids: Optional[torch.Tensor] = None, |
| image_num_crops: Optional[torch.Tensor] = None, |
| pixel_values_videos: Optional[torch.Tensor] = None, |
| video_token_pooling: Optional[torch.Tensor] = None, |
| video_grids: Optional[torch.Tensor] = None, |
| ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: |
| if pixel_values is not None and pixel_values_videos is not None: |
| raise ValueError( |
| "pixel_values and pixel_values_videos are provided at the same time" |
| ) |
| elif pixel_values is not None: |
| assert input_ids is not None |
| images, token_pooling = self.build_batched_images( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| image_token_pooling=image_token_pooling, |
| image_grids=image_grids, |
| image_num_crops=image_num_crops, |
| ) |
| elif pixel_values_videos is not None: |
| assert input_ids is not None |
| images, token_pooling = self.build_batched_videos( |
| input_ids=input_ids, |
| pixel_values_videos=pixel_values_videos, |
| video_token_pooling=video_token_pooling, |
| video_grids=video_grids, |
| ) |
| else: |
| images, token_pooling = None, None |
| return images, token_pooling |
|
|
| def build_input_embeddings( |
| self, |
| input_ids: torch.LongTensor, |
| images: Optional[torch.FloatTensor] = None, |
| token_pooling: Optional[torch.LongTensor] = None, |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
|
| |
| |
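| # Map -1 padding ids to 0 so the embedding lookup stays in range. |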
| input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) |
| x = self.transformer.wte(input_ids) |
|
|
| image_features: Optional[torch.FloatTensor] = None |
| if images is not None: |
| image_features = self.vision_backbone(images, token_pooling).to(x.device) |
| is_image_patch = input_ids.view(-1) == self.config.image_patch_id |
| assert is_image_patch.sum() == len(image_features) |
| x.view(-1, x.shape[-1])[is_image_patch] += image_features |
|
|
| |
| x = self.transformer.emb_drop(x) |
|
|
| return x, image_features |
|
|
| def _build_native_attention_bias( |
| self, |
| *, |
| inputs_embeds: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| token_type_ids: Optional[torch.Tensor], |
| past_key_values: Optional[Cache], |
| ) -> torch.Tensor: |
| if attention_mask is not None and attention_mask.ndim == 4: |
| return attention_mask.to(device=inputs_embeds.device) |
| batch_size, seq_len = inputs_embeds.shape[:2] |
| past_length = _cache_seq_len_int(past_key_values) |
| current_length = past_length + int(seq_len) |
| max_cache_len = _cache_max_len_int(past_key_values) |
| attention_mask_len = max_cache_len if max_cache_len > 0 else current_length |
| device = inputs_embeds.device |
|
|
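| # Build a per-position key validity mask over the full (possibly static) cache length. |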
| if attention_mask is None: |
| positions = torch.arange(attention_mask_len, device=device) |
| valid_mask = positions.unsqueeze(0) < current_length |
| valid_mask = valid_mask.expand(batch_size, -1) |
| elif attention_mask.ndim == 2: |
| valid_mask = torch.zeros( |
| (batch_size, attention_mask_len), device=device, dtype=torch.bool |
| ) |
| source_mask = attention_mask.to(device=device, dtype=torch.bool) |
| copy_len = min(int(source_mask.shape[-1]), attention_mask_len) |
| if copy_len > 0: |
| valid_mask[:, :copy_len] = source_mask[:, :copy_len] |
| if attention_mask_len > current_length: |
| valid_mask[:, current_length:] = False |
| else: |
| raise ValueError( |
| f"Unsupported attention_mask shape for MolmoAct2: {tuple(attention_mask.shape)}" |
| ) |
|
|
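| # Build a lower-triangular causal mask restricted to the current query rows. |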
| valid_mask = valid_mask[:, None, None, :] |
| causal_mask = torch.tril( |
| torch.ones( |
| attention_mask_len, attention_mask_len, device=device, dtype=torch.bool |
| ) |
| )[None, None, past_length:current_length, :attention_mask_len] |
|
|
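| # During prefill, allow image tokens (flagged by token_type_ids) to attend to each other bidirectionally. |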
| if token_type_ids is not None and past_length == 0: |
| image_mask = token_type_ids.to(device=device, dtype=torch.bool) |
| can_attend_back = image_mask[:, :, None] & image_mask[:, None, :] |
| image_len = min(int(token_type_ids.shape[1]), attention_mask_len) |
| # Expand the causal mask across the batch before the in-place update so the |
| # per-example image masks broadcast correctly when batch_size > 1. |
| causal_mask = causal_mask.expand(batch_size, 1, -1, -1).clone() |
| causal_mask[:, :, :, :image_len] = ( |
| causal_mask[:, :, :, :image_len] |
| | can_attend_back[:, None, :, :image_len] |
| ) |
|
|
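| # Convert the boolean "allowed" mask into an additive bias: 0 where attention is allowed, dtype-min elsewhere. |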
| allowed = valid_mask & causal_mask |
| return torch.where( |
| allowed, |
| torch.zeros((), device=device, dtype=inputs_embeds.dtype), |
| torch.full( |
| (), |
| torch.finfo(inputs_embeds.dtype).min, |
| device=device, |
| dtype=inputs_embeds.dtype, |
| ), |
| ) |
|
|
| @can_return_tuple |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| image_token_pooling: Optional[torch.Tensor] = None, |
| image_grids: Optional[torch.Tensor] = None, |
| image_num_crops: Optional[torch.Tensor] = None, |
| pixel_values_videos: Optional[torch.Tensor] = None, |
| video_token_pooling: Optional[torch.Tensor] = None, |
| video_grids: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Cache] = None, |
| token_type_ids: Optional[torch.LongTensor] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> Union[tuple, MolmoAct2ModelOutputWithPast]: |
|
|
| output_attentions = ( |
| output_attentions |
| if output_attentions is not None |
| else self.config.output_attentions |
| ) |
| output_hidden_states = ( |
| output_hidden_states |
| if output_hidden_states is not None |
| else self.config.output_hidden_states |
| ) |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError( |
| "You must specify exactly one of input_ids or inputs_embeds" |
| ) |
|
|
| images, token_pooling = self.merge_visual_inputs( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| image_token_pooling=image_token_pooling, |
| image_grids=image_grids, |
| image_num_crops=image_num_crops, |
| pixel_values_videos=pixel_values_videos, |
| video_token_pooling=video_token_pooling, |
| video_grids=video_grids, |
| ) |
|
|
| if images is not None and inputs_embeds is not None: |
| raise ValueError( |
| "You cannot specify both images and inputs_embeds at the same time." |
| ) |
|
|
| if inputs_embeds is None: |
| inputs_embeds, image_features = self.build_input_embeddings( |
| input_ids, |
| images, |
| token_pooling, |
| ) |
|
|
| if cache_position is None: |
| past_seen_tokens = _cache_seq_len_int(past_key_values) |
| cache_position = torch.arange( |
| past_seen_tokens, |
| past_seen_tokens + inputs_embeds.shape[1], |
| device=inputs_embeds.device, |
| ) |
|
|
| if isinstance(attention_mask, dict): |
| causal_mask_mapping = attention_mask |
| else: |
| causal_mask_mapping = self._build_native_attention_bias( |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| past_key_values=past_key_values, |
| ) |
|
|
| outputs = self.transformer( |
| attention_mask=causal_mask_mapping, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| return MolmoAct2ModelOutputWithPast( |
| last_hidden_state=outputs.last_hidden_state, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| image_hidden_states=image_features if images is not None else None, |
| ) |
|
|
|
|
| class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixin): |
| _checkpoint_conversion_mapping = {} |
| _tied_weights_keys = [] |
| |
| accepts_loss_kwargs = False |
| config: MolmoAct2Config |
|
|
| def __init__(self, config: MolmoAct2Config): |
| super().__init__(config) |
|
|
| self.model = MolmoAct2Model(config) |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| self.vocab_size = config.vocab_size |
| self.model.action_cuda_graph_manager = ActionCudaGraphManager(self.model) |
| self.depth_decode_cuda_graph_manager = DepthDecodeCudaGraphManager(self) |
| |
| self.post_init() |
|
|
| def get_input_embeddings(self) -> torch.nn.Module: |
| return self.model.transformer.wte |
|
|
| def set_input_embeddings(self, value: torch.nn.Module) -> None: |
| self.model.transformer.wte = value |
|
|
| def set_decoder(self, decoder): |
| self.model.set_decoder(decoder) |
|
|
| def get_decoder(self): |
| return self.model.get_decoder() |
|
|
| |
| @property |
| def language_model(self) -> torch.nn.Module: |
| return self.model.transformer |
|
|
| @property |
| def vision_backbone(self) -> torch.nn.Module: |
| return self.model.vision_backbone |
|
|
| def _get_robot_stats(self) -> _RobotStats: |
| stats = getattr(self, "_molmoact2_robot_stats", None) |
| if stats is not None: |
| return stats |
| filename = getattr(self.config, "norm_stats_filename", "norm_stats.json") |
| base_dir = getattr(self.config, "_name_or_path", None) or getattr( |
| self, "name_or_path", None |
| ) |
| if not base_dir: |
| raise ValueError( |
| "MolmoAct2 normalization stats are not loaded and config._name_or_path is empty; " |
| "load the model from a converted HF directory containing norm_stats.json." |
| ) |
| stats_path = os.path.join(str(base_dir), filename) |
| if not os.path.isfile(stats_path): |
| raise FileNotFoundError( |
| f"MolmoAct2 normalization stats file is missing: {stats_path}. " |
| "Converted checkpoints must include norm_stats.json." |
| ) |
| with open(stats_path, "r", encoding="utf-8") as f: |
| payload = json.load(f) |
| stats = _RobotStats(payload) |
| self._molmoact2_robot_stats = stats |
| return stats |
|
|
| @staticmethod |
| def _move_inputs_to_device( |
| inputs: Mapping[str, Any], device: torch.device |
| ) -> Dict[str, Any]: |
| out = {} |
| for key, value in inputs.items(): |
| out[key] = value.to(device) if torch.is_tensor(value) else value |
| return out |
|
|
| @staticmethod |
| def _drop_trivial_attention_mask(inputs: Mapping[str, Any]) -> Dict[str, Any]: |
| out = dict(inputs) |
| attention_mask = out.get("attention_mask") |
| if torch.is_tensor(attention_mask) and bool( |
| attention_mask.to(dtype=torch.bool).all().item() |
| ): |
| out.pop("attention_mask", None) |
| return out |
|
|
| @staticmethod |
| def _count_images(images: Any) -> int: |
| if images is None: |
| return 0 |
| if isinstance(images, (list, tuple)): |
| return len(images) |
| arr = np.asarray(images) if not torch.is_tensor(images) else images |
| if getattr(arr, "ndim", 0) == 4: |
| return int(arr.shape[0]) |
| return 1 |
|
|
| @staticmethod |
| def _build_action_dim_is_pad( |
| *, |
| action_dim: int, |
| max_action_dim: int, |
| batch_size: int, |
| device: torch.device, |
| ) -> Optional[torch.Tensor]: |
| if int(action_dim) > int(max_action_dim): |
| raise ValueError( |
| f"Requested action_dim {int(action_dim)} exceeds checkpoint max_action_dim {int(max_action_dim)}." |
| ) |
| if int(action_dim) == int(max_action_dim): |
| return None |
| mask = torch.ones( |
| (int(batch_size), int(max_action_dim)), device=device, dtype=torch.bool |
| ) |
| mask[:, : int(action_dim)] = False |
| return mask |
|
|
| @staticmethod |
| def _slice_action_dim(actions: torch.Tensor, action_dim: int) -> torch.Tensor: |
| if actions.shape[-1] < int(action_dim): |
| raise ValueError( |
| f"Requested action_dim {int(action_dim)} but chunk only has width {actions.shape[-1]}." |
| ) |
| return actions[..., : int(action_dim)] |
|
|
| @staticmethod |
| def _slice_action_chunk( |
| actions: torch.Tensor, n_obs_steps: int, n_action_steps: Optional[int] |
| ) -> torch.Tensor: |
| if n_action_steps is None: |
| return actions |
| start = int(n_obs_steps) - 1 |
| end = start + int(n_action_steps) |
| if end > actions.shape[1]: |
| raise ValueError( |
| f"Requested actions up to {end} but model produced horizon {actions.shape[1]}." |
| ) |
| return actions[:, start:end] |
|
|
| def _depth_token_id_to_bin(self) -> Dict[int, int]: |
| if ( |
| self.config.depth_token_start_id is None |
| or int(self.config.num_depth_tokens or 0) <= 0 |
| ): |
| return {} |
| start = int(self.config.depth_token_start_id) |
| return {start + idx: idx for idx in range(int(self.config.num_depth_tokens))} |
|
|
| def _action_token_id_to_bin(self) -> Dict[int, int]: |
| if ( |
| self.config.action_token_start_id is None |
| or int(self.config.num_action_tokens or 0) <= 0 |
| ): |
| return {} |
| start = int(self.config.action_token_start_id) |
| return {start + idx: idx for idx in range(int(self.config.num_action_tokens))} |
|
|
| def _require_eos_token_id(self) -> int: |
| eos_token_id = getattr(self.config, "eos_token_id", None) |
| if ( |
| eos_token_id is None |
| and getattr(self, "generation_config", None) is not None |
| ): |
| eos_token_id = getattr(self.generation_config, "eos_token_id", None) |
| if isinstance(eos_token_id, (list, tuple)): |
| eos_token_id = eos_token_id[0] if eos_token_id else None |
| if eos_token_id is None: |
| raise RuntimeError( |
| "Discrete action generation requires `eos_token_id` in the converted HF config." |
| ) |
| return int(eos_token_id) |
|
|
| def _decode_depth_bins_from_token_ids( |
| self, token_ids: torch.Tensor |
| ) -> torch.Tensor: |
| if ( |
| self.config.depth_start_token_id is None |
| or self.config.depth_end_token_id is None |
| ): |
| raise RuntimeError( |
| "Depth generation requires <depth_start>/<depth_end> token IDs." |
| ) |
| token_id_to_bin = self._depth_token_id_to_bin() |
| if not token_id_to_bin: |
| raise RuntimeError( |
| "Depth generation requires indexed depth tokens in the converted config." |
| ) |
| depth_token_bins = _extract_discrete_token_bins( |
| _flatten_generated_token_ids(token_ids), |
| int(self.config.depth_start_token_id), |
| int(self.config.depth_end_token_id), |
| token_id_to_bin, |
| ) |
| if not depth_token_bins: |
| raise RuntimeError( |
| "Model generated no decodable depth tokens between <depth_start>/<depth_end>." |
| ) |
| return torch.as_tensor([depth_token_bins], device=self.device, dtype=torch.long) |
|
|
| def _consume_generation_tokens( |
| self, |
| token_ids: torch.Tensor, |
| *, |
| past_key_values: Optional[Cache], |
| attention_mask: Optional[torch.Tensor], |
| ) -> Tuple[MolmoAct2CausalLMOutputWithPast, Optional[torch.Tensor]]: |
| if token_ids.ndim == 1: |
| next_input_ids = token_ids.unsqueeze(1) |
| elif token_ids.ndim == 2: |
| next_input_ids = token_ids |
| else: |
| raise ValueError( |
| f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}." |
| ) |
| next_attention_mask = attention_mask |
| if next_attention_mask is not None: |
| past_length = _cache_seq_len_int(past_key_values) |
| required_len = int(past_length) + int(next_input_ids.shape[1]) |
| if int(next_attention_mask.shape[-1]) < required_len: |
| pad_len = required_len - int(next_attention_mask.shape[-1]) |
| next_attention_mask = torch.cat( |
| ( |
| next_attention_mask, |
| next_attention_mask.new_ones( |
| (next_input_ids.shape[0], pad_len) |
| ), |
| ), |
| dim=-1, |
| ) |
| past_length = _cache_seq_len_int(past_key_values) |
| output = self( |
| input_ids=next_input_ids, |
| attention_mask=next_attention_mask, |
| past_key_values=past_key_values, |
| use_cache=True, |
| cache_position=( |
| torch.arange( |
| past_length, |
| past_length + int(next_input_ids.shape[1]), |
| device=next_input_ids.device, |
| ) |
| if past_key_values is not None |
| else None |
| ), |
| ) |
| return output, next_attention_mask |
|
|
| def _make_depth_decode_attention_bias( |
| self, inputs: Mapping[str, Any], past_key_values: Cache |
| ) -> torch.Tensor: |
| layers = getattr(past_key_values, "layers", None) |
| max_cache_len = int(getattr(layers[0], "max_cache_len", 0)) if layers else 0 |
| if max_cache_len <= 0: |
| raise RuntimeError( |
| "Depth decode fast path requires a cache with a fixed maximum length." |
| ) |
| input_ids = inputs["input_ids"] |
| batch_size = int(input_ids.shape[0]) |
| device = input_ids.device |
| dtype = self.lm_head.weight.dtype |
|
|
| positions = torch.arange(max_cache_len, device=device, dtype=torch.long) |
| valid_mask = torch.ones( |
| (batch_size, max_cache_len), device=device, dtype=torch.bool |
| ) |
| attention_mask = inputs.get("attention_mask") |
| if attention_mask is not None: |
| source_mask = attention_mask.to(device=device, dtype=torch.bool) |
| copy_len = min(int(source_mask.shape[-1]), max_cache_len) |
| if copy_len > 0: |
| valid_mask[:, :copy_len] = source_mask[:, :copy_len] |
| causal_mask = positions[None, :] <= positions[:, None] |
| allowed = causal_mask.unsqueeze(0) & valid_mask[:, None, :] |
| attention_bias = torch.where( |
| allowed[:, None, :, :], |
| torch.zeros((), device=device, dtype=dtype), |
| torch.full((), torch.finfo(dtype).min, device=device, dtype=dtype), |
| ) |
| return attention_bias |
|
|
| def _embed_base_tokens(self, input_ids: torch.Tensor) -> torch.Tensor: |
| |
| |
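| # Embed with the raw base weight table when the wte module exposes one; otherwise fall back to calling the module. |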
| wte = self.model.transformer.wte |
| base_embedding = getattr(wte, "embedding", None) |
| if base_embedding is None: |
| return wte(input_ids) |
| return F.embedding(input_ids, base_embedding) |
|
|
| def _run_depth_decode_step( |
| self, |
| token_ids: torch.Tensor, |
| *, |
| past_key_values: Cache, |
| attention_bias: torch.Tensor, |
| ) -> Tuple[torch.Tensor, Cache]: |
| if token_ids.ndim == 1: |
| next_input_ids = token_ids.unsqueeze(1) |
| elif token_ids.ndim == 2: |
| next_input_ids = token_ids |
| else: |
| raise ValueError( |
| f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}." |
| ) |
| past_length = _cache_seq_len_int(past_key_values) |
| end = past_length + int(next_input_ids.shape[1]) |
| if self.depth_decode_cuda_graph_manager.can_use( |
| next_input_ids, |
| past_key_values=past_key_values, |
| attention_bias=attention_bias, |
| ): |
| return self.depth_decode_cuda_graph_manager.run( |
| next_input_ids, |
| past_key_values=past_key_values, |
| attention_bias=attention_bias, |
| past_length=past_length, |
| ) |
| cache_position = torch.arange( |
| past_length, end, device=next_input_ids.device, dtype=torch.long |
| ) |
| attention_bias = attention_bias[:, :, past_length:end, :end] |
| inputs_embeds = self._embed_base_tokens(next_input_ids) |
| outputs = self.model.transformer( |
| attention_mask=attention_bias, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=True, |
| output_attentions=False, |
| output_hidden_states=False, |
| cache_position=cache_position, |
| ) |
| return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values |
|
|
| def _project_depth_logits(self, last_hidden: torch.Tensor) -> torch.Tensor: |
| start = int(self.config.depth_token_start_id) |
| end_id = start + int(self.config.num_depth_tokens) |
| return F.linear(last_hidden, self.lm_head.weight[start:end_id]) |
|
|
| def _max_depth_decode_steps(self) -> int: |
| return max( |
| int(self.config.num_depth_codes or 0) + 8, |
| int(self.config.action_horizon or 1) * 16, |
| 1, |
| ) |
|
|
| def _make_depth_static_cache(self, inputs: Mapping[str, Any]) -> Cache: |
| prompt_len = inputs["input_ids"].shape[1] |
| action_horizon = int(self.config.action_horizon or 1) |
| max_end_steps = max(8, action_horizon) |
| action_token_budget = max(1, action_horizon * 16) |
| return self.depth_decode_cuda_graph_manager.make_static_cache( |
| max_cache_len=prompt_len |
| + self._max_depth_decode_steps() |
| + max_end_steps |
| + action_token_budget, |
| ) |
|
|
| def _continue_discrete_generation_from_output( |
| self, |
| initial_output: MolmoAct2CausalLMOutputWithPast, |
| *, |
| past_key_values: Optional[Cache], |
| attention_mask: Optional[torch.Tensor], |
| end_token_id: int, |
| max_steps: int, |
| ) -> torch.Tensor: |
| generated_tokens: List[torch.Tensor] = [] |
| current_output = initial_output |
| current_past_key_values = past_key_values |
| current_attention_mask = attention_mask |
| hit_end = False |
| for _ in range(int(max_steps)): |
| next_token = torch.argmax(current_output.logits[:, -1, :], dim=-1) |
| generated_tokens.append(next_token) |
| if bool((next_token == int(end_token_id)).all()): |
| hit_end = True |
| break |
| current_output, current_attention_mask = self._consume_generation_tokens( |
| next_token, |
| past_key_values=current_past_key_values, |
| attention_mask=current_attention_mask, |
| ) |
| current_past_key_values = current_output.past_key_values |
| if not generated_tokens: |
| raise RuntimeError("Discrete continuation generated no tokens.") |
| if not hit_end: |
| raise RuntimeError( |
| f"Discrete continuation did not emit end token {int(end_token_id)} within {int(max_steps)} steps." |
| ) |
| return torch.stack(generated_tokens, dim=1) |
|
|
| def _generate_depth_prefix( |
| self, |
| inputs: Mapping[str, Any], |
| *, |
| latest_first_image: Optional[np.ndarray], |
| depth_cache: Optional[Mapping[str, Any]], |
| enable_adaptive_depth: bool, |
| ) -> _DepthPrefix: |
| if ( |
| self.config.depth_start_token_id is None |
| or self.config.depth_end_token_id is None |
| ): |
| raise RuntimeError( |
| "Depth reasoning requires single-token <depth_start>/<depth_end>." |
| ) |
| if ( |
| self.config.depth_token_start_id is None |
| or int(self.config.num_depth_tokens or 0) <= 0 |
| ): |
| raise RuntimeError("Depth reasoning requires indexed depth tokens.") |
| batch_size = int(inputs["input_ids"].shape[0]) |
| if batch_size != 1 and enable_adaptive_depth: |
| raise ValueError("enable_adaptive_depth=True currently supports batch size 1.") |
| static_cache = self._make_depth_static_cache(inputs) |
| output = self(**inputs, use_cache=True, past_key_values=static_cache) |
| current_output = output |
| current_past_key_values = output.past_key_values |
| current_attention_mask = inputs.get("attention_mask") |
| generated_tokens: List[torch.Tensor] = [] |
|
|
| if not enable_adaptive_depth: |
| hit_depth_end = False |
| max_steps = self._max_depth_decode_steps() |
| for _ in range(max_steps): |
| next_token = torch.argmax(current_output.logits[:, -1, :], dim=-1) |
| generated_tokens.append(next_token) |
| current_output, current_attention_mask = ( |
| self._consume_generation_tokens( |
| next_token, |
| past_key_values=current_past_key_values, |
| attention_mask=current_attention_mask, |
| ) |
| ) |
| current_past_key_values = current_output.past_key_values |
| if bool((next_token == int(self.config.depth_end_token_id)).all()): |
| hit_depth_end = True |
| break |
| if not generated_tokens: |
| raise RuntimeError("Depth generation produced no tokens.") |
| if not hit_depth_end: |
| raise RuntimeError( |
| f"Depth generation did not emit <depth_end> within {max_steps} steps." |
| ) |
| depth_token_ids = torch.stack(generated_tokens, dim=1) |
| full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1) |
| full_attention_mask = None |
| if current_attention_mask is not None: |
| full_attention_mask = current_attention_mask[ |
| :, : full_input_ids.shape[1] |
| ] |
| encoder_kv_states = self.model._extract_kv_states(current_past_key_values) |
| return _DepthPrefix( |
| token_ids=depth_token_ids, |
| depth_bins=self._decode_depth_bins_from_token_ids(depth_token_ids), |
| full_input_ids=full_input_ids, |
| attention_mask=full_attention_mask, |
| encoder_kv_states=encoder_kv_states, |
| next_output=current_output, |
| past_key_values=current_past_key_values, |
| ) |
|
|
| depth_start = torch.full( |
| (batch_size,), |
| int(self.config.depth_start_token_id), |
| device=self.device, |
| dtype=torch.long, |
| ) |
| code_token_ids = torch.arange( |
| int(self.config.depth_token_start_id), |
| int(self.config.depth_token_start_id) + int(self.config.num_depth_tokens), |
| device=self.device, |
| dtype=torch.long, |
| ) |
| depth_attention_bias = self._make_depth_decode_attention_bias( |
| inputs, current_past_key_values |
| ) |
| generated_tokens.append(depth_start) |
| last_hidden, current_past_key_values = self._run_depth_decode_step( |
| depth_start, |
| past_key_values=current_past_key_values, |
| attention_bias=depth_attention_bias, |
| ) |
| previous_image = None |
| previous_bins = None |
| if depth_cache is not None: |
| previous_image = depth_cache.get("image") |
| previous_bins = depth_cache.get("depth_bins") |
| selective = ( |
| bool(enable_adaptive_depth) |
| and latest_first_image is not None |
| and previous_image is not None |
| and previous_bins is not None |
| ) |
| update_mask = None |
| previous_buffer_t = None |
| if selective: |
| previous_buffer = np.asarray(previous_bins, dtype=np.int64).reshape(-1) |
| if previous_buffer.shape[0] == int(self.config.num_depth_codes): |
| update_mask = _compute_depth_update_mask( |
| latest_first_image, |
| _normalize_image_for_cache(previous_image), |
| num_depth_codes=int(self.config.num_depth_codes), |
| ) |
| previous_buffer_t = ( |
| torch.from_numpy(previous_buffer) |
| .to( |
| device=self.device, |
| dtype=torch.long, |
| ) |
| .unsqueeze(0) |
| ) |
| else: |
| selective = False |
|
|
| depth_bins = torch.zeros( |
| (batch_size, int(self.config.num_depth_codes)), |
| device=self.device, |
| dtype=torch.long, |
| ) |
| num_depth_codes = int(self.config.num_depth_codes) |
| if not selective or update_mask is None or previous_buffer_t is None: |
| for depth_idx in range(num_depth_codes): |
| depth_logits = self._project_depth_logits(last_hidden) |
| predicted_bins = depth_logits.squeeze(1).argmax(dim=-1) |
| depth_bins[:, depth_idx] = predicted_bins |
| chosen_token_ids = code_token_ids[predicted_bins] |
| generated_tokens.append(chosen_token_ids) |
| last_hidden, current_past_key_values = self._run_depth_decode_step( |
| chosen_token_ids, |
| past_key_values=current_past_key_values, |
| attention_bias=depth_attention_bias, |
| ) |
| else: |
| for start_idx, end_idx, should_generate in _build_depth_update_spans( |
| update_mask |
| ): |
| if should_generate: |
| for depth_idx in range(start_idx, end_idx): |
| depth_logits = self._project_depth_logits(last_hidden) |
| predicted_bins = depth_logits.squeeze(1).argmax(dim=-1) |
| depth_bins[:, depth_idx] = predicted_bins |
| chosen_token_ids = code_token_ids[predicted_bins] |
| generated_tokens.append(chosen_token_ids) |
| last_hidden, current_past_key_values = ( |
| self._run_depth_decode_step( |
| chosen_token_ids, |
| past_key_values=current_past_key_values, |
| attention_bias=depth_attention_bias, |
| ) |
| ) |
| continue |
| replay_bins = previous_buffer_t[:, start_idx:end_idx].expand( |
| batch_size, -1 |
| ) |
| depth_bins[:, start_idx:end_idx] = replay_bins |
| replay_token_ids = code_token_ids[replay_bins] |
| generated_tokens.extend(replay_token_ids.unbind(dim=1)) |
| last_hidden, current_past_key_values = self._run_depth_decode_step( |
| replay_token_ids, |
| past_key_values=current_past_key_values, |
| attention_bias=depth_attention_bias, |
| ) |
| hit_depth_end = False |
| max_depth_end_steps = max(8, int(self.config.action_horizon or 1)) |
| full_logits = self.lm_head(last_hidden) |
| for _ in range(max_depth_end_steps): |
| next_token = full_logits.squeeze(1).argmax(dim=-1) |
| generated_tokens.append(next_token) |
| last_hidden, current_past_key_values = self._run_depth_decode_step( |
| next_token, |
| past_key_values=current_past_key_values, |
| attention_bias=depth_attention_bias, |
| ) |
| full_logits = self.lm_head(last_hidden) |
| if bool((next_token == int(self.config.depth_end_token_id)).all()): |
| hit_depth_end = True |
| break |
| if not hit_depth_end: |
| raise RuntimeError( |
| f"Depth generation did not emit <depth_end> within {max_depth_end_steps} steps " |
| "after adaptive depth tokens." |
| ) |
|
|
| depth_token_ids = torch.stack(generated_tokens, dim=1) |
| full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1) |
| attention_mask = inputs.get("attention_mask") |
| if attention_mask is not None: |
| full_attention_mask = torch.cat( |
| (attention_mask, attention_mask.new_ones(depth_token_ids.shape)), |
| dim=-1, |
| )[:, : full_input_ids.shape[1]] |
| else: |
| full_attention_mask = None |
| current_output = MolmoAct2CausalLMOutputWithPast( |
| logits=full_logits, |
| past_key_values=current_past_key_values, |
| ) |
| encoder_kv_states = self.model._extract_kv_states(current_past_key_values) |
| return _DepthPrefix( |
| token_ids=depth_token_ids, |
| depth_bins=depth_bins, |
| full_input_ids=full_input_ids, |
| attention_mask=full_attention_mask, |
| encoder_kv_states=encoder_kv_states, |
| next_output=current_output, |
| past_key_values=current_past_key_values, |
| ) |
|
|
| def _decode_discrete_action_chunk( |
| self, |
| generated_token_ids: torch.Tensor, |
| *, |
| action_tokenizer: Any, |
| action_dim: int, |
| ) -> torch.Tensor: |
| if action_tokenizer is None: |
| raise ValueError( |
| "action_mode='discrete' requires an `action_tokenizer` input." |
| ) |
| if ( |
| self.config.action_start_token_id is None |
| or self.config.action_end_token_id is None |
| ): |
| raise RuntimeError( |
| "Discrete action generation requires <action_start>/<action_end> token IDs." |
| ) |
| token_id_to_bin = self._action_token_id_to_bin() |
| if not token_id_to_bin: |
| raise RuntimeError( |
| "Discrete action generation requires indexed action tokens in the converted config." |
| ) |
| discrete_token_ids = _extract_discrete_token_bins( |
| _flatten_generated_token_ids(generated_token_ids), |
| int(self.config.action_start_token_id), |
| int(self.config.action_end_token_id), |
| token_id_to_bin, |
| ) |
| if not discrete_token_ids: |
| raise RuntimeError( |
| "Model generated no decodable action tokens between <action_start>/<action_end>." |
| ) |
| try: |
| decoded = action_tokenizer.decode( |
| [discrete_token_ids], |
| time_horizon=int(self.config.action_horizon), |
| action_dim=int(action_dim), |
| ) |
| except TypeError: |
| decoded = action_tokenizer.decode([discrete_token_ids]) |
| action_chunk = np.asarray(decoded, dtype=np.float32) |
| if action_chunk.ndim == 1: |
| action_chunk = action_chunk[None, None, :] |
| elif action_chunk.ndim == 2: |
| action_chunk = action_chunk[None, :, :] |
| elif action_chunk.ndim > 3: |
| action_chunk = action_chunk.reshape( |
| 1, action_chunk.shape[-2], action_chunk.shape[-1] |
| ) |
| if action_chunk.ndim != 3: |
| raise RuntimeError( |
| f"Decoded action chunk has unexpected shape {action_chunk.shape}." |
| ) |
| return torch.as_tensor(action_chunk, device=self.device, dtype=torch.float32) |
|
|
| @torch.no_grad() |
| def predict_action( |
| self, |
| *, |
| processor: Any, |
| images: Any, |
| task: str, |
| state: Any, |
| norm_tag: str, |
| action_mode: str = "continuous", |
| enable_depth_reasoning: bool = False, |
| enable_adaptive_depth: bool = True, |
| depth_cache: Optional[Mapping[str, Any]] = None, |
| action_tokenizer: Any = None, |
| num_steps: Optional[int] = None, |
| n_action_steps: Optional[int] = None, |
| generator: Optional[torch.Generator] = None, |
| normalize_language: bool = True, |
| enable_cuda_graph: bool = True, |
| return_dict: bool = True, |
| ) -> Union[MolmoAct2ActionOutput, torch.Tensor]: |
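| """Predict an action chunk from images, a task instruction, and the robot state. |
|
| A minimal usage sketch; the processor, frame, state, and norm tag below are |
| illustrative and depend on the converted checkpoint and its norm_stats.json: |
|
| >>> out = model.predict_action( |
| ...     processor=processor, |
| ...     images=[frame], |
| ...     task="pick up the red block", |
| ...     state=robot_state, |
| ...     norm_tag="example_robot", |
| ... ) |
| >>> out.actions  # (batch, n_action_steps, action_dim), unnormalized |
| """ |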
| if not bool(self.config.add_action_expert): |
| raise RuntimeError( |
| "This MolmoAct2 checkpoint was converted with add_action_expert=False; " |
| "use standard Transformers generation for VLM inference." |
| ) |
| if state is None: |
| raise ValueError( |
| "MolmoAct2 `predict_action` requires `state` for discrete state prompting." |
| ) |
| action_mode = str(action_mode or "continuous") |
| if action_mode not in {"continuous", "discrete"}: |
| raise ValueError("action_mode must be either 'continuous' or 'discrete'.") |
| if action_mode == "continuous" and self.config.action_format not in { |
| "continuous", |
| "both", |
| }: |
| raise ValueError( |
| f"action_mode='continuous' requires checkpoint action_format in {{'continuous', 'both'}}, " |
| f"got {self.config.action_format!r}." |
| ) |
| if action_mode == "discrete": |
| if action_tokenizer is None: |
| raise ValueError( |
| "action_mode='discrete' requires an `action_tokenizer` input." |
| ) |
| if self.config.action_format not in {"discrete", "both"}: |
| raise ValueError( |
| f"action_mode='discrete' requires checkpoint action_format in {{'discrete', 'both'}}, " |
| f"got {self.config.action_format!r}." |
| ) |
| if enable_depth_reasoning and not bool(self.config.enable_depth_reasoning): |
| raise ValueError( |
| "this model was not trained with `--enable_depth_reasoning`." |
| ) |
|
|
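| # Normalize the robot state with the per-tag stats and build the discrete state string for the prompt. |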
| stats = self._get_robot_stats() |
| norm_tag = stats.validate_tag(norm_tag) |
| metadata = stats.get_metadata(norm_tag) |
| normalized_state = np.asarray( |
| stats.normalize_state(state, norm_tag), dtype=np.float32 |
| ) |
| num_state_tokens = int(self.config.num_state_tokens or 0) |
| if num_state_tokens <= 0: |
| raise RuntimeError( |
| "Discrete state prompting requires indexed state tokens in the converted config." |
| ) |
| discrete_state_string = _build_discrete_state_string( |
| normalized_state, num_state_tokens |
| ) |
| style = "robot_depth_action" if enable_depth_reasoning else "robot_action" |
| task_text = str(task or "") |
| if normalize_language: |
| task_text = _normalize_question_text(task_text) |
| text = _build_robot_text( |
| task=task_text, |
| style=style, |
| discrete_state_string=discrete_state_string, |
| setup_type=str(metadata.get("setup_type", "") or ""), |
| control_mode=str(metadata.get("control_mode", "") or ""), |
| add_setup_tokens=bool(self.config.add_setup_tokens), |
| add_control_tokens=bool(self.config.add_control_tokens), |
| num_images=self._count_images(images), |
| ) |
| inputs = processor(text=text, images=images, return_tensors="pt") |
| inputs = self._move_inputs_to_device(inputs, self.device) |
| inputs = self._drop_trivial_attention_mask(inputs) |
|
|
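| # Resolve action_dim, action_horizon, and n_action_steps from the stats tag, bounded by the checkpoint config. |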
| action_dim = stats.get_action_dim(norm_tag) |
| if action_dim is None: |
| action_dim = int(self.config.max_action_dim) |
| action_dim = int(action_dim) |
| action_horizon = stats.get_action_horizon(norm_tag) or int( |
| self.config.action_horizon |
| ) |
| if int(action_horizon) > int(self.config.action_horizon): |
| raise ValueError( |
| f"Tag action_horizon={int(action_horizon)} exceeds checkpoint action_horizon={int(self.config.action_horizon)}." |
| ) |
| resolved_n_action_steps = n_action_steps |
| if resolved_n_action_steps is None: |
| resolved_n_action_steps = stats.get_n_action_steps(norm_tag) |
| if resolved_n_action_steps is None: |
| resolved_n_action_steps = int(action_horizon) |
| resolved_n_action_steps = int(resolved_n_action_steps) |
| if resolved_n_action_steps < 1: |
| raise ValueError( |
| f"n_action_steps must be >= 1, got {resolved_n_action_steps}." |
| ) |
| if resolved_n_action_steps > int(action_horizon): |
| raise ValueError( |
| f"Requested n_action_steps={resolved_n_action_steps} exceeds tag action_horizon={int(action_horizon)}." |
| ) |
| batch_size = int(inputs["input_ids"].shape[0]) |
| action_dim_is_pad = self._build_action_dim_is_pad( |
| action_dim=action_dim, |
| max_action_dim=int(self.config.max_action_dim), |
| batch_size=batch_size, |
| device=self.device, |
| ) |
| self.model.action_cuda_graph_manager.set_enabled(enable_cuda_graph) |
| self.depth_decode_cuda_graph_manager.set_enabled(enable_cuda_graph) |
|
|
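| # Continuous mode samples a chunk with the action expert; discrete mode autoregressively decodes |
| # action tokens and maps them back through the action tokenizer. |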
| generated_token_ids = None |
| depth_bins = None |
| updated_depth_cache = depth_cache |
| if action_mode == "continuous": |
| if enable_depth_reasoning: |
| latest_first_image = _extract_first_image(images) |
| depth_prefix = self._generate_depth_prefix( |
| inputs, |
| latest_first_image=latest_first_image, |
| depth_cache=depth_cache, |
| enable_adaptive_depth=bool(enable_adaptive_depth), |
| ) |
| generated_token_ids = depth_prefix.token_ids |
| depth_bins = depth_prefix.depth_bins |
| actions = self.model.generate_actions_from_inputs( |
| input_ids=depth_prefix.full_input_ids, |
| attention_mask=depth_prefix.attention_mask, |
| action_dim_is_pad=action_dim_is_pad, |
| num_steps=num_steps, |
| generator=generator, |
| encoder_kv_states=depth_prefix.encoder_kv_states, |
| encoder_attention_mask=self.model._get_encoder_attention_mask( |
| depth_prefix.full_input_ids, |
| depth_prefix.attention_mask, |
| ), |
| ) |
| if latest_first_image is not None: |
| updated_depth_cache = { |
| "image": latest_first_image, |
| "depth_bins": depth_bins.detach() |
| .cpu() |
| .reshape(-1) |
| .numpy() |
| .astype(np.int64), |
| } |
| else: |
| actions = self.model.generate_actions_from_inputs( |
| **inputs, |
| action_dim_is_pad=action_dim_is_pad, |
| num_steps=num_steps, |
| generator=generator, |
| ) |
| else: |
| if enable_depth_reasoning: |
| latest_first_image = _extract_first_image(images) |
| depth_prefix = self._generate_depth_prefix( |
| inputs, |
| latest_first_image=latest_first_image, |
| depth_cache=depth_cache, |
| enable_adaptive_depth=bool(enable_adaptive_depth), |
| ) |
| action_token_ids = self._continue_discrete_generation_from_output( |
| depth_prefix.next_output, |
| past_key_values=depth_prefix.past_key_values, |
| attention_mask=depth_prefix.attention_mask, |
| end_token_id=self._require_eos_token_id(), |
| max_steps=max(1, int(self.config.action_horizon * 16)), |
| ) |
| generated_token_ids = torch.cat( |
| [depth_prefix.token_ids, action_token_ids], dim=1 |
| ) |
| depth_bins = depth_prefix.depth_bins |
| if latest_first_image is not None: |
| updated_depth_cache = { |
| "image": latest_first_image, |
| "depth_bins": depth_bins.detach() |
| .cpu() |
| .reshape(-1) |
| .numpy() |
| .astype(np.int64), |
| } |
| else: |
| prefill_output = self(**inputs, use_cache=True) |
| action_token_ids = self._continue_discrete_generation_from_output( |
| prefill_output, |
| past_key_values=prefill_output.past_key_values, |
| attention_mask=inputs.get("attention_mask"), |
| end_token_id=self._require_eos_token_id(), |
| max_steps=max(1, int(self.config.action_horizon * 16)), |
| ) |
| generated_token_ids = action_token_ids |
| actions = self._decode_discrete_action_chunk( |
| generated_token_ids, |
| action_tokenizer=action_tokenizer, |
| action_dim=action_dim, |
| ) |
|
|
| actions = self._slice_action_dim(actions, action_dim) |
| actions = self._slice_action_chunk( |
| actions, int(self.config.n_obs_steps), resolved_n_action_steps |
| ) |
| actions = stats.unnormalize_action(actions, norm_tag) |
| if not torch.is_tensor(actions): |
| actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32) |
| else: |
| actions = actions.to(device=self.device, dtype=torch.float32) |
| output = MolmoAct2ActionOutput( |
| actions=actions, |
| generated_token_ids=generated_token_ids, |
| depth_bins=depth_bins, |
| depth_cache=updated_depth_cache, |
| ) |
| if return_dict: |
| return output |
| return actions |
|
|
| @can_return_tuple |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| image_token_pooling: Optional[torch.Tensor] = None, |
| image_grids: Optional[torch.Tensor] = None, |
| image_num_crops: Optional[torch.Tensor] = None, |
| pixel_values_videos: Optional[torch.Tensor] = None, |
| video_token_pooling: Optional[torch.Tensor] = None, |
| video_grids: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[list[torch.FloatTensor]] = None, |
| token_type_ids: Optional[torch.LongTensor] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> Union[tuple, MolmoAct2CausalLMOutputWithPast]: |
| r""" |
| ```python |
| >>> from PIL import Image |
| >>> import requests |
| >>> from transformers import AutoProcessor, MolmoAct2ForConditionalGeneration |
| |
| >>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...") |
| >>> processor = AutoProcessor.from_pretrained("...") |
| |
| >>> prompt = "What's the content of the image?" |
| >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" |
| >>> image = Image.open(requests.get(url, stream=True).raw) |
| |
| >>> messages = [{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}] |
| |
| >>> inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True) |
| |
| >>> # Generate |
| >>> generated_ids = model.generate(**inputs, max_new_tokens=15) |
| >>> generated_tokens = generated_ids[:, inputs['input_ids'].size(1):] |
| >>> processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| "The image shows a bustling street scene in what appears to be a Chinatown area. There's ..." |
| ```""" |
| outputs = self.model( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| image_token_pooling=image_token_pooling, |
| image_grids=image_grids, |
| image_num_crops=image_num_crops, |
| pixel_values_videos=pixel_values_videos, |
| video_token_pooling=video_token_pooling, |
| video_grids=video_grids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| token_type_ids=token_type_ids, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
| |
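| # Only compute logits for the trailing `logits_to_keep` positions (0 keeps all positions). |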
| slice_indices = ( |
| slice(-logits_to_keep, None) |
| if isinstance(logits_to_keep, int) |
| else logits_to_keep |
| ) |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
| loss = None |
| if labels is not None: |
| loss = self.loss_function( |
| logits=logits, labels=labels, vocab_size=self.vocab_size |
| ) |
|
|
| return MolmoAct2CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| image_hidden_states=outputs.image_hidden_states, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids: torch.LongTensor, |
| past_key_values: Optional[list[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| image_token_pooling: Optional[torch.Tensor] = None, |
| image_grids: Optional[torch.Tensor] = None, |
| image_num_crops: Optional[torch.Tensor] = None, |
| pixel_values_videos: Optional[torch.Tensor] = None, |
| video_token_pooling: Optional[torch.Tensor] = None, |
| video_grids: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| token_type_ids: Optional[torch.LongTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Optional[Union[int, torch.Tensor]] = None, |
| **kwargs, |
| ): |
|
|
| model_inputs = super().prepare_inputs_for_generation( |
| input_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| logits_to_keep=logits_to_keep, |
| token_type_ids=token_type_ids, |
| **kwargs, |
| ) |
|
|
| include_visual_inputs = past_key_values is None |
| if past_key_values is not None and hasattr(past_key_values, "get_seq_length"): |
| include_visual_inputs = int(past_key_values.get_seq_length()) == 0 |
| if include_visual_inputs: |
| model_inputs["pixel_values"] = pixel_values |
| model_inputs["image_token_pooling"] = image_token_pooling |
| model_inputs["image_grids"] = image_grids |
| model_inputs["image_num_crops"] = image_num_crops |
| model_inputs["pixel_values_videos"] = pixel_values_videos |
| model_inputs["video_token_pooling"] = video_token_pooling |
| model_inputs["video_grids"] = video_grids |
|
|
| return model_inputs |
|
|
| |
| @staticmethod |
| def create_masks_for_generate( |
| config: PretrainedConfig, |
| input_embeds: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| cache_position: torch.Tensor, |
| past_key_values: Optional[Cache], |
| position_ids: Optional[torch.Tensor], |
| token_type_ids: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> dict: |
| |
| mask_kwargs = { |
| "config": config.get_text_config(), |
| "input_embeds": input_embeds, |
| "attention_mask": attention_mask, |
| "cache_position": cache_position, |
| "past_key_values": past_key_values, |
| "position_ids": position_ids, |
| } |
| |
| if token_type_ids is not None and input_embeds.shape[1] != 1: |
| |
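| # Mirror the bidirectional image-token attention from _build_native_attention_bias when prefilling more than one token. |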
| mask_kwargs["or_mask_function"] = token_type_ids_mask_function( |
| token_type_ids.to(cache_position.device) |
| ) |
|
|
| return create_masks_for_generate(**mask_kwargs) |
|
|
|
|
| |
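| # Register with the Auto classes so AutoModelForImageTextToText.from_pretrained can load MolmoAct2 checkpoints. |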
| AutoModelForImageTextToText.register(MolmoAct2Config, MolmoAct2ForConditionalGeneration) |
|
|