| """Inference utilities for MolmoAct2""" |
|
|
| from dataclasses import dataclass |
| from typing import Any, Iterable, Optional, Sequence, Tuple |
|
|
| import torch |
| from torch.nn import functional as F |
| from transformers.cache_utils import Cache |
| from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
| @dataclass |
| class _ActionFlowInputs: |
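    """Inputs to one action flow loop: the trajectory, the cached attention
    context, the per-step modulations, and the action-dimension padding mask."""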
| trajectory: torch.Tensor |
| context: Any |
| modulations: Sequence[Any] |
| action_dim_is_pad: Optional[torch.Tensor] |
|
|
|
|
| @dataclass |
| class _ActionFlowCudaGraph: |
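    """A captured action-flow CUDA graph together with its static input buffers,
    its static output tensor, and the signature key it was captured under."""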
| key: Tuple[Any, ...] |
| graph: torch.cuda.CUDAGraph |
| static_inputs: _ActionFlowInputs |
| output: torch.Tensor |
|
|
|
|
| @dataclass |
| class _DepthDecodeCudaGraphLayerStage: |
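    """Static outputs of one graphed pre-attention stage: the residual stream and
    the query/key/value tensors fed to attention for that layer."""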
| residual: torch.Tensor |
| query: torch.Tensor |
| key: torch.Tensor |
| value: torch.Tensor |
|
|
|
|
| @dataclass |
| class _DepthDecodeCudaGraphPostStage: |
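    """One graphed post-attention stage and the static buffer its attention
    context is copied into before replay."""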
| graph: torch.cuda.CUDAGraph |
| attn_context: torch.Tensor |
|
|
|
|
| @dataclass |
| class _DepthDecodeCudaGraph: |
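    """Everything needed to decode a single depth token: the pre-graph, the shared
    static token/cos/sin/position buffers, the per-layer stages and post-graphs,
    and the final output tensor."""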
| cache_key: Tuple[Any, ...] |
| pre_graph: torch.cuda.CUDAGraph |
| token_ids: torch.Tensor |
| cos: torch.Tensor |
| sin: torch.Tensor |
| positions: torch.Tensor |
| stages: Sequence[_DepthDecodeCudaGraphLayerStage] |
| post_graphs: Sequence[_DepthDecodeCudaGraphPostStage] |
| output: torch.Tensor |
|
|
|
|
| @dataclass |
| class _DepthDecodeCudaGraphSpec: |
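    """Config-derived eligibility flag and cache-key prefix, computed once per
    manager and reused across decode steps."""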
| eligible: bool |
| cache_key_prefix: Tuple[Any, ...] |
| num_hidden_layers: int |
| head_dim: int |
| num_attention_heads: int |
|
|
|
|
| def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int: |
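    """Return the cache's current sequence length as a plain int (0 if no cache)."""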
| if past_key_values is None: |
| return 0 |
| seq_len = past_key_values.get_seq_length() |
| if torch.is_tensor(seq_len): |
| return int(seq_len.item()) |
| return int(seq_len) |
|
|
|
|
| def _cache_max_len_int(past_key_values: Optional[Cache]) -> int: |
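    """Return the cache's maximum length as a plain int (-1 if no cache)."""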
| if past_key_values is None: |
| return -1 |
| max_len = past_key_values.get_max_cache_shape() |
| if torch.is_tensor(max_len): |
| return int(max_len.item()) |
| return int(max_len) |
|
|
|
|
| def _iter_cache_key_values( |
| past_key_values: Cache, |
| ) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]: |
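    """Yield per-layer (keys, values) pairs, handling both layer-object caches and
    legacy tuple-style caches."""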
| layers = getattr(past_key_values, "layers", None) |
| if layers is not None: |
| for layer in layers: |
| yield getattr(layer, "keys", None), getattr(layer, "values", None) |
| return |
| for layer in past_key_values: |
| yield layer[0], layer[1] |
|
|
|
|
| class _DepthDecodeStaticLayerCache: |
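    """Per-layer KV cache that writes new keys/values into preallocated buffers of
    length max_cache_len, so storage stays at a fixed address across decode steps
    rather than growing by concatenation."""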
| is_compileable = False |
| is_sliding = False |
|
|
| def __init__(self, max_cache_len: int) -> None: |
| self.max_cache_len = int(max_cache_len) |
| self.cumulative_length = 0 |
| self.keys: Optional[torch.Tensor] = None |
| self.values: Optional[torch.Tensor] = None |
|
|
| def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: |
| bsz, n_heads = key_states.shape[:2] |
| self.keys = torch.empty( |
| (bsz, n_heads, self.max_cache_len, key_states.shape[-1]), |
| dtype=key_states.dtype, |
| device=key_states.device, |
| ) |
| self.values = torch.empty( |
| (bsz, n_heads, self.max_cache_len, value_states.shape[-1]), |
| dtype=value_states.dtype, |
| device=value_states.device, |
| ) |
|
|
| def update( |
| self, |
| key_states: torch.Tensor, |
| value_states: torch.Tensor, |
| *args, |
| **kwargs, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if self.keys is None: |
| self._allocate(key_states, value_states) |
| start = self.cumulative_length |
| end = start + key_states.shape[-2] |
| if end > self.max_cache_len: |
| raise RuntimeError( |
| f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}." |
| ) |
| self.keys[:, :, start:end, :].copy_(key_states) |
| self.values[:, :, start:end, :].copy_(value_states) |
| self.cumulative_length = end |
| return self.keys[:, :, :end, :], self.values[:, :, :end, :] |
|
|
| def get_seq_length(self) -> int: |
| return self.cumulative_length |
|
|
| def get_max_cache_shape(self) -> int: |
| return -1 |
|
|
| def reset(self) -> None: |
| self.cumulative_length = 0 |
|
|
|
|
| class _DepthDecodeStaticCache(Cache): |
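    """Static KV cache with one _DepthDecodeStaticLayerCache per decoder layer;
    DepthDecodeCudaGraphManager.can_use requires this cache type."""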
| def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None: |
| text_config = config.get_text_config(decoder=True) |
| super().__init__( |
| layers=[ |
| _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len) |
| for _ in range(text_config.num_hidden_layers) |
| ] |
| ) |
|
|
| def get_seq_length(self, layer_idx: int = 0) -> int: |
| return self.layers[layer_idx].get_seq_length() |
|
|
| def get_max_cache_shape(self, layer_idx: int = 0) -> int: |
| return self.layers[layer_idx].get_max_cache_shape() |
|
|
| def reset(self) -> None: |
| for layer in self.layers: |
| layer.reset() |
|
|
|
|
| class ActionCudaGraphManager: |
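    """Captures the action flow loop as a single CUDA graph and replays it whenever
    the input signature (shapes, strides, dtypes, devices, step count) matches."""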
| def __init__(self, model: Any) -> None: |
| self.model = model |
| self.enabled = True |
| self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None |
|
|
| def set_enabled(self, enabled: bool) -> None: |
| self.enabled = bool(enabled) |
|
|
| def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool: |
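        """Use the graph only when the manager is enabled, nothing is in training
        mode, and every tensor the loop touches lives on CUDA."""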
| action_model = self.model |
| if not self.enabled: |
| return False |
| if action_model.training or action_model._require_action_expert().training: |
| return False |
| if inputs.trajectory.device.type != "cuda": |
| return False |
|
|
| def all_on_cuda(): |
| yield inputs.trajectory |
| for k, v in inputs.context.kv_contexts: |
| yield k |
| yield v |
| for t in ( |
| inputs.context.cross_mask, |
| inputs.context.self_mask, |
| inputs.context.valid_action, |
| inputs.action_dim_is_pad, |
| ): |
| if t is not None: |
| yield t |
| if inputs.context.rope_cache is not None: |
| yield from inputs.context.rope_cache |
| for step in inputs.modulations: |
| yield step.conditioning |
| for block_modulation in step.block_modulations: |
| yield from block_modulation |
| yield from step.final_modulation |
|
|
| return all(t.device.type == "cuda" for t in all_on_cuda()) |
|
|
| def run_action_flow( |
| self, |
| inputs: _ActionFlowInputs, |
| steps: int, |
| run_loop, |
| ) -> torch.Tensor: |
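        """Run the flow loop via CUDA graph replay, capturing a fresh graph first
        if the input signature changed since the previous call."""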
| key = _cuda_graph_key(inputs, steps) |
| cache = self.action_flow_graph |
| if cache is None or cache.key != key: |
| static_inputs = _clone_static_inputs(inputs) |
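            # The warmup pass inside _capture_cuda_graph may advance the static
            # trajectory in place; re-seed it from the caller's trajectory so the
            # capture starts from the intended state.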
| graph, output = _capture_cuda_graph( |
| lambda: run_loop(static_inputs, steps), |
| inputs.trajectory.device, |
| after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory), |
| ) |
| cache = _ActionFlowCudaGraph( |
| key=key, |
| graph=graph, |
| static_inputs=static_inputs, |
| output=output, |
| ) |
| self.action_flow_graph = cache |
| else: |
| _copy_inputs_(cache.static_inputs, inputs) |
|
|
| cache.graph.replay() |
| return cache.output.clone() |
|
|
|
|
| class DepthDecodeCudaGraphManager: |
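    """Single-token depth decoding with CUDA graphs: each layer's pre-attention and
    post-attention compute is captured as graphs, while the KV-cache update and
    SDPA (whose key length grows every step) run eagerly between replays."""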
| def __init__(self, model: Any) -> None: |
| self.model = model |
| self.backbone = model.model |
| self.enabled = True |
| self.graph: Optional[_DepthDecodeCudaGraph] = None |
| self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None |
|
|
| def set_enabled(self, enabled: bool) -> None: |
| self.enabled = bool(enabled) |
|
|
| def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache: |
| return _DepthDecodeStaticCache( |
| config=self.model.config.text_config, |
| max_cache_len=max_cache_len, |
| ) |
|
|
| def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec: |
| static = self.graph_spec |
| if static is None: |
| cfg = self.backbone.transformer.config |
| rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None) |
| static = _DepthDecodeCudaGraphSpec( |
| eligible=( |
| not cfg.norm_after |
| and cfg.rope_scaling_layers is None |
| and getattr(rotary_emb, "rope_type", None) == "default" |
| and cfg._attn_implementation == "sdpa" |
| ), |
| cache_key_prefix=( |
| cfg.hidden_size, |
| cfg.num_attention_heads, |
| cfg.num_key_value_heads, |
| cfg.head_dim, |
| cfg.num_hidden_layers, |
| cfg.use_qk_norm, |
| cfg.qk_norm_type, |
| cfg._attn_implementation, |
| ), |
| num_hidden_layers=cfg.num_hidden_layers, |
| head_dim=cfg.head_dim, |
| num_attention_heads=cfg.num_attention_heads, |
| ) |
| self.graph_spec = static |
| return static |
|
|
| def can_use( |
| self, |
| next_input_ids: torch.Tensor, |
| *, |
| past_key_values: Cache, |
| attention_bias: torch.Tensor, |
| ) -> bool: |
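        """Graphed decode requires the manager enabled and in eval mode, a (1, 1)
        CUDA input, a _DepthDecodeStaticCache, an attention bias on the same
        device, and an eligible text config."""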
| if ( |
| not self.enabled |
| or self.model.training |
| or self.backbone.transformer.training |
| ): |
| return False |
| if next_input_ids.device.type != "cuda": |
| return False |
| if ( |
| next_input_ids.ndim != 2 |
| or next_input_ids.shape[0] != 1 |
| or next_input_ids.shape[1] != 1 |
| ): |
| return False |
| if not isinstance(past_key_values, _DepthDecodeStaticCache): |
| return False |
| if ( |
| not torch.is_tensor(attention_bias) |
| or attention_bias.device != next_input_ids.device |
| ): |
| return False |
| return self._depth_decode_spec().eligible |
|
|
| def _depth_decode_key( |
| self, |
| next_input_ids: torch.Tensor, |
| attention_bias: torch.Tensor, |
| ) -> Tuple[Any, ...]: |
| device = next_input_ids.device |
| return ( |
| self._depth_decode_spec().cache_key_prefix, |
| device.type, |
| device.index, |
| self.model.lm_head.weight.dtype, |
| attention_bias.shape[-1], |
| ) |
|
|
| def _select_depth_decode_rope( |
| self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int |
| ) -> None: |
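        """Copy the rotary cos/sin for the single position at past_length into the
        graph's static buffers."""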
| emb = self.backbone.transformer.rotary_emb |
| cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :]) |
| sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :]) |
|
|
| def _depth_decode_pre_layer( |
| self, |
| layer_idx: int, |
| hidden_states: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
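        """Graph-safe first half of a decoder layer: attention norm, fused QKV
        projection, optional QK norm, and RoPE; returns the residual plus the
        transposed query/key/value states."""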
| block = self.backbone.transformer.blocks[layer_idx] |
| attention = block.self_attn |
| residual = hidden_states |
| hidden_states = block.attn_norm(hidden_states) |
|
|
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, attention.head_dim) |
| qkv = attention.att_proj(hidden_states) |
| query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1) |
| value_states = value_states.view(hidden_shape) |
|
|
| apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None |
| norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3" |
|
|
| if apply_qk_norm and not norm_after_view: |
| query_states = attention.q_norm(query_states) |
| key_states = attention.k_norm(key_states) |
|
|
| query_states = query_states.view(hidden_shape) |
| key_states = key_states.view(hidden_shape) |
|
|
| if norm_after_view: |
| query_states = attention.q_norm(query_states) |
| key_states = attention.k_norm(key_states) |
|
|
| query_states = query_states.transpose(1, 2) |
| key_states = key_states.transpose(1, 2) |
| value_states = value_states.transpose(1, 2) |
| query_states, key_states = _apply_rotary_pos_emb( |
| query_states, key_states, cos, sin |
| ) |
| return residual, query_states, key_states, value_states |
|
|
| def _depth_decode_pre0( |
| self, |
| token_ids: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| inputs_embeds = self.model._embed_base_tokens(token_ids) |
| return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin) |
|
|
| def _depth_decode_post_layer( |
| self, |
| layer_idx: int, |
| residual: torch.Tensor, |
| attn_context: torch.Tensor, |
| ) -> torch.Tensor: |
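        """Graph-safe second half of a decoder layer: output projection, residual
        add, and the feed-forward block."""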
| block = self.backbone.transformer.blocks[layer_idx] |
| attention = block.self_attn |
| input_shape = residual.shape[:-1] |
| attn_output = attn_context.reshape(*input_shape, -1).contiguous() |
| attn_output = attention.attn_out(attn_output) |
| hidden_states = residual + block.dropout(attn_output) |
|
|
| residual = hidden_states |
| hidden_states = block.ff_norm(hidden_states) |
| hidden_states = block.mlp(hidden_states) |
| hidden_states = residual + block.dropout(hidden_states) |
| return hidden_states |
|
|
| def _depth_decode_post_and_pre_next( |
| self, |
| layer_idx: int, |
| residual: torch.Tensor, |
| attn_context: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context) |
| return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin) |
|
|
| def _depth_decode_last_post( |
| self, |
| layer_idx: int, |
| residual: torch.Tensor, |
| attn_context: torch.Tensor, |
| ) -> torch.Tensor: |
| hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context) |
| return self.backbone.transformer.ln_f(hidden_states) |
|
|
| def _build_depth_decode_graph( |
| self, |
| next_input_ids: torch.Tensor, |
| *, |
| past_length: int, |
| attention_bias: torch.Tensor, |
| ) -> _DepthDecodeCudaGraph: |
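        """Capture the full decode step: one pre-graph covering the token embedding
        and the first layer's pre-attention, a fused post+pre graph for each
        intermediate layer, and a final post graph ending in the last layer norm.
        All graphs share the static token/cos/sin buffers."""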
| text_config = self.backbone.transformer.config |
| device = next_input_ids.device |
| dtype = self.model.lm_head.weight.dtype |
| static = self._depth_decode_spec() |
| num_layers = static.num_hidden_layers |
| head_dim = static.head_dim |
| max_cache_len = int(attention_bias.shape[-1]) |
| max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len) |
| self.backbone.transformer.prepare_rope_cache( |
| device=device, max_seq_len=max_rope_len |
| ) |
|
|
| token_ids = torch.empty((1, 1), device=device, dtype=torch.long) |
| cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype) |
| sin = torch.empty_like(cos) |
| positions = torch.arange(max_cache_len, device=device, dtype=torch.long) |
| context_shape = (1, 1, static.num_attention_heads, head_dim) |
|
|
| token_ids.copy_(next_input_ids) |
| self._select_depth_decode_rope(cos, sin, past_length=past_length) |
|
|
| pre_graph, pre_output = _capture_cuda_graph( |
| lambda: self._depth_decode_pre0(token_ids, cos, sin), |
| device, |
| ) |
| stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)] |
| post_graphs = [] |
| for layer_idx in range(num_layers - 1): |
| stage = stages[-1] |
| attn_context = torch.empty(context_shape, device=device, dtype=dtype) |
| graph, output = _capture_cuda_graph( |
| lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: ( |
| self._depth_decode_post_and_pre_next( |
| layer_idx, |
| stage.residual, |
| attn_context, |
| cos, |
| sin, |
| ) |
| ), |
| device, |
| ) |
| post_graphs.append( |
| _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context) |
| ) |
| stages.append(_DepthDecodeCudaGraphLayerStage(*output)) |
|
|
| last_stage = stages[-1] |
| last_attn_context = torch.empty(context_shape, device=device, dtype=dtype) |
| last_graph, last_output = _capture_cuda_graph( |
| lambda: self._depth_decode_last_post( |
| num_layers - 1, |
| last_stage.residual, |
| last_attn_context, |
| ), |
| device, |
| ) |
| post_graphs.append( |
| _DepthDecodeCudaGraphPostStage( |
| graph=last_graph, attn_context=last_attn_context |
| ) |
| ) |
| return _DepthDecodeCudaGraph( |
| cache_key=self._depth_decode_key(next_input_ids, attention_bias), |
| pre_graph=pre_graph, |
| token_ids=token_ids, |
| cos=cos, |
| sin=sin, |
| positions=positions, |
| stages=tuple(stages), |
| post_graphs=tuple(post_graphs), |
| output=last_output, |
| ) |
|
|
| def _get_depth_decode_graph( |
| self, |
| next_input_ids: torch.Tensor, |
| *, |
| past_length: int, |
| attention_bias: torch.Tensor, |
| ) -> _DepthDecodeCudaGraph: |
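        """Return the cached decode graph, rebuilding it when the cache key changes;
        otherwise just refresh the static token and rope buffers."""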
| key = self._depth_decode_key(next_input_ids, attention_bias) |
| decode_graph = self.graph |
| if decode_graph is None or decode_graph.cache_key != key: |
| decode_graph = self._build_depth_decode_graph( |
| next_input_ids, |
| past_length=past_length, |
| attention_bias=attention_bias, |
| ) |
| self.graph = decode_graph |
| else: |
| decode_graph.token_ids.copy_(next_input_ids) |
| self._select_depth_decode_rope( |
| decode_graph.cos, decode_graph.sin, past_length=past_length |
| ) |
| return decode_graph |
|
|
| def _run_depth_decode_attention_core( |
| self, |
| layer_idx: int, |
| stage: _DepthDecodeCudaGraphLayerStage, |
| *, |
| past_key_values: Cache, |
| attention_bias: torch.Tensor, |
| cache_position: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| ) -> torch.Tensor: |
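        """Eager (non-graphed) attention for one layer: append this step's key/value
        to the cache, then run SDPA over the full cached sequence."""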
| attention = self.backbone.transformer.blocks[layer_idx].self_attn |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| key_states, value_states = past_key_values.update( |
| stage.key, |
| stage.value, |
| layer_idx, |
| cache_kwargs, |
| ) |
| key_states = _repeat_kv(key_states, attention.num_key_value_groups) |
| value_states = _repeat_kv(value_states, attention.num_key_value_groups) |
| attn_output = F.scaled_dot_product_attention( |
| stage.query, |
| key_states, |
| value_states, |
| attn_mask=attention_bias, |
| dropout_p=0.0, |
| is_causal=False, |
| ) |
| return attn_output.transpose(1, 2) |
|
|
| def run( |
| self, |
| next_input_ids: torch.Tensor, |
| *, |
| past_key_values: Cache, |
| attention_bias: torch.Tensor, |
| past_length: int, |
| ) -> Tuple[torch.Tensor, Cache]: |
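        """Decode one token: replay the pre-graph, then for each layer run the eager
        attention core and replay that layer's post-graph; returns the final normed
        hidden state and the cache (updated in place)."""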
| end = past_length + 1 |
| decode_graph = self._get_depth_decode_graph( |
| next_input_ids, |
| past_length=past_length, |
| attention_bias=attention_bias, |
| ) |
| cache_position = decode_graph.positions[past_length:end] |
| attention_bias_q = attention_bias[:, :, past_length:end, :end] |
|
|
| decode_graph.pre_graph.replay() |
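        # Attention itself is not graphed: the KV cache grows by one position each
        # step, so the cache update + SDPA run eagerly between graph replays.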
|
|
| for layer_idx, post_graph in enumerate(decode_graph.post_graphs): |
| attn_context = self._run_depth_decode_attention_core( |
| layer_idx, |
| decode_graph.stages[layer_idx], |
| past_key_values=past_key_values, |
| attention_bias=attention_bias_q, |
| cache_position=cache_position, |
| cos=decode_graph.cos, |
| sin=decode_graph.sin, |
| ) |
| post_graph.attn_context.copy_(attn_context) |
| post_graph.graph.replay() |
|
|
| return decode_graph.output, past_key_values |
|
|
|
|
| def _cuda_graph_tensor_signature( |
| tensor: Optional[torch.Tensor], |
| ) -> Optional[Tuple[Any, ...]]: |
| if tensor is None: |
| return None |
| return ( |
| tuple(tensor.shape), |
| tuple(tensor.stride()), |
| str(tensor.dtype), |
| str(tensor.device), |
| ) |
|
|
|
|
| def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]: |
| sig = _cuda_graph_tensor_signature |
| return ( |
| tuple((sig(k), sig(v)) for k, v in context.kv_contexts), |
| sig(context.cross_mask), |
| sig(context.self_mask), |
| sig(context.valid_action), |
| None |
| if context.rope_cache is None |
| else tuple(sig(t) for t in context.rope_cache), |
| ) |
|
|
|
|
| def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]: |
| sig = _cuda_graph_tensor_signature |
| return tuple( |
| ( |
| sig(step.conditioning), |
| tuple( |
| tuple(sig(t) for t in block_modulation) |
| for block_modulation in step.block_modulations |
| ), |
| tuple(sig(t) for t in step.final_modulation), |
| ) |
| for step in modulations |
| ) |
|
|
|
|
| def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]: |
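    """Full input signature (tensor shapes/strides/dtypes/devices plus step count)
    used to decide whether a previously captured graph can be reused."""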
| sig = _cuda_graph_tensor_signature |
| return ( |
| sig(inputs.trajectory), |
| _cuda_graph_context_signature(inputs.context), |
| _cuda_graph_modulation_signature(inputs.modulations), |
| sig(inputs.action_dim_is_pad), |
| int(steps), |
| ) |
|
|
|
|
| def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: |
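    """Allocate a buffer with the same shape/strides and copy the tensor into it,
    giving the graph a stable address to capture."""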
| if tensor is None: |
| return None |
| static = torch.empty_strided( |
| tuple(tensor.shape), |
| tuple(tensor.stride()), |
| device=tensor.device, |
| dtype=tensor.dtype, |
| ) |
| static.copy_(tensor) |
| return static |
|
|
|
|
| def _clone_static_context(context: Any) -> Any: |
| rope_cache = None |
| if context.rope_cache is not None: |
| rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache) |
| return context.__class__( |
| kv_contexts=tuple( |
| (_clone_static_tensor(k), _clone_static_tensor(v)) |
| for k, v in context.kv_contexts |
| ), |
| cross_mask=_clone_static_tensor(context.cross_mask), |
| self_mask=_clone_static_tensor(context.self_mask), |
| valid_action=_clone_static_tensor(context.valid_action), |
| rope_cache=rope_cache, |
| ) |
|
|
|
|
| def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]: |
| return tuple( |
| step.__class__( |
| conditioning=_clone_static_tensor(step.conditioning), |
| block_modulations=tuple( |
| tuple(_clone_static_tensor(t) for t in block_modulation) |
| for block_modulation in step.block_modulations |
| ), |
| final_modulation=tuple( |
| _clone_static_tensor(t) for t in step.final_modulation |
| ), |
| ) |
| for step in modulations |
| ) |
|
|
|
|
| def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs: |
| return _ActionFlowInputs( |
| trajectory=_clone_static_tensor(inputs.trajectory), |
| context=_clone_static_context(inputs.context), |
| modulations=_clone_static_modulations(inputs.modulations), |
| action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad), |
| ) |
|
|
|
|
| def _copy_context_(dst: Any, src: Any) -> None: |
| for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts): |
| dst_k.copy_(src_k) |
| dst_v.copy_(src_v) |
| if src.cross_mask is not None: |
| dst.cross_mask.copy_(src.cross_mask) |
| if src.self_mask is not None: |
| dst.self_mask.copy_(src.self_mask) |
| if src.valid_action is not None: |
| dst.valid_action.copy_(src.valid_action) |
| if src.rope_cache is not None: |
| for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache): |
| dst_tensor.copy_(src_tensor) |
|
|
|
|
| def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None: |
| dst.trajectory.copy_(src.trajectory) |
| _copy_context_(dst.context, src.context) |
| if src.action_dim_is_pad is not None: |
| dst.action_dim_is_pad.copy_(src.action_dim_is_pad) |
|
|
|
|
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def _apply_rotary_pos_emb( |
| q: torch.Tensor, |
| k: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| unsqueeze_dim: int = 1, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| q_embed = (q * cos) + (_rotate_half(q) * sin) |
| k_embed = (k * cos) + (_rotate_half(k) * sin) |
| return q_embed, k_embed |
|
|
|
|
| def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
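    """Repeat KV heads n_rep times so grouped-query key/value states match the
    number of query heads (the standard transformers repeat_kv expansion)."""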
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| if n_rep == 1: |
| return hidden_states |
| hidden_states = hidden_states[:, :, None, :, :].expand( |
| batch, num_key_value_heads, n_rep, slen, head_dim |
| ) |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
| def _capture_cuda_graph( |
| fn, |
| device: torch.device, |
| *, |
| after_warmup=None, |
| ) -> Tuple[torch.cuda.CUDAGraph, Any]: |
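    """Warm up fn once on a side stream (so allocations and lazy initialization
    happen outside capture), optionally run after_warmup to reset static inputs,
    then capture fn into a CUDA graph and return the graph with its outputs."""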
| warmup_stream = torch.cuda.Stream(device=device) |
| warmup_stream.wait_stream(torch.cuda.current_stream(device)) |
| with torch.cuda.stream(warmup_stream): |
| fn() |
| torch.cuda.current_stream(device).wait_stream(warmup_stream) |
| if after_warmup is not None: |
| after_warmup() |
|
|
| graph = torch.cuda.CUDAGraph() |
| with torch.cuda.graph(graph): |
| output = fn() |
| return graph, output |
|
|