hqfang commited on
Commit
eecdbf4
·
verified ·
1 Parent(s): 5f712be

Enable CUDA graph AR decode for discrete actions

Browse files

Generalizes the existing single-token decode CUDA graph helper and uses it for discrete-action autoregressive decoding without depth tokens when enable_cuda_graph=True. The original KV-cache path is preserved when enable_cuda_graph=False.

Files changed (1) hide show
  1. modeling_molmoact2.py +59 -10
modeling_molmoact2.py CHANGED
@@ -4127,14 +4127,14 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4127
 
4128
  def _embed_base_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
4129
  # Skips MolmoAct2Embedding's per-call cat([base, new]); safe only for IDs
4130
- # below text_config.vocab_size, which is the case for all depth tokens.
4131
  wte = self.model.transformer.wte
4132
  base_embedding = getattr(wte, "embedding", None)
4133
  if base_embedding is None:
4134
  return wte(input_ids)
4135
  return F.embedding(input_ids, base_embedding)
4136
 
4137
- def _run_depth_decode_step(
4138
  self,
4139
  token_ids: torch.Tensor,
4140
  *,
@@ -4178,6 +4178,19 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4178
  )
4179
  return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values
4180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4181
  def _project_depth_logits(self, last_hidden: torch.Tensor) -> torch.Tensor:
4182
  start = int(self.config.depth_token_start_id)
4183
  end_id = start + int(self.config.num_depth_tokens)
@@ -4190,6 +4203,12 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4190
  1,
4191
  )
4192
 
 
 
 
 
 
 
4193
  def _make_depth_static_cache(self, inputs: Mapping[str, Any]) -> Cache:
4194
  prompt_len = inputs["input_ids"].shape[1]
4195
  action_horizon = int(self.config.action_horizon or 1)
@@ -4210,6 +4229,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4210
  attention_mask: Optional[torch.Tensor],
4211
  end_token_id: int,
4212
  max_steps: int,
 
4213
  ) -> torch.Tensor:
4214
  generated_tokens: List[torch.Tensor] = []
4215
  current_output = initial_output
@@ -4222,12 +4242,23 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4222
  if bool((next_token == int(end_token_id)).all()):
4223
  hit_end = True
4224
  break
4225
- current_output, current_attention_mask = self._consume_generation_tokens(
4226
- next_token,
4227
- past_key_values=current_past_key_values,
4228
- attention_mask=current_attention_mask,
4229
- )
4230
- current_past_key_values = current_output.past_key_values
 
 
 
 
 
 
 
 
 
 
 
4231
  if not generated_tokens:
4232
  raise RuntimeError("Discrete continuation generated no tokens.")
4233
  if not hit_end:
@@ -4705,13 +4736,31 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4705
  .astype(np.int64),
4706
  }
4707
  else:
4708
- prefill_output = self(**inputs, use_cache=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4709
  action_token_ids = self._continue_discrete_generation_from_output(
4710
  prefill_output,
4711
  past_key_values=prefill_output.past_key_values,
4712
  attention_mask=inputs.get("attention_mask"),
4713
  end_token_id=self._require_eos_token_id(),
4714
- max_steps=max(1, int(self.config.action_horizon * 16)),
 
4715
  )
4716
  generated_token_ids = action_token_ids
4717
  actions = self._decode_discrete_action_chunk(
 
4127
 
4128
  def _embed_base_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
4129
  # Skips MolmoAct2Embedding's per-call cat([base, new]); safe only for IDs
4130
+ # below text_config.vocab_size. This includes released depth/action tokens.
4131
  wte = self.model.transformer.wte
4132
  base_embedding = getattr(wte, "embedding", None)
4133
  if base_embedding is None:
4134
  return wte(input_ids)
4135
  return F.embedding(input_ids, base_embedding)
4136
 
4137
+ def _run_ar_decode_step(
4138
  self,
4139
  token_ids: torch.Tensor,
4140
  *,
 
4178
  )
4179
  return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values
4180
 
4181
+ def _run_depth_decode_step(
4182
+ self,
4183
+ token_ids: torch.Tensor,
4184
+ *,
4185
+ past_key_values: Cache,
4186
+ attention_bias: torch.Tensor,
4187
+ ) -> Tuple[torch.Tensor, Cache]:
4188
+ return self._run_ar_decode_step(
4189
+ token_ids,
4190
+ past_key_values=past_key_values,
4191
+ attention_bias=attention_bias,
4192
+ )
4193
+
4194
  def _project_depth_logits(self, last_hidden: torch.Tensor) -> torch.Tensor:
4195
  start = int(self.config.depth_token_start_id)
4196
  end_id = start + int(self.config.num_depth_tokens)
 
4203
  1,
4204
  )
4205
 
4206
+ def _make_ar_decode_static_cache(self, inputs: Mapping[str, Any], max_steps: int) -> Cache:
4207
+ prompt_len = inputs["input_ids"].shape[1]
4208
+ return self.depth_decode_cuda_graph_manager.make_static_cache(
4209
+ max_cache_len=prompt_len + max(1, int(max_steps)),
4210
+ )
4211
+
4212
  def _make_depth_static_cache(self, inputs: Mapping[str, Any]) -> Cache:
4213
  prompt_len = inputs["input_ids"].shape[1]
4214
  action_horizon = int(self.config.action_horizon or 1)
 
4229
  attention_mask: Optional[torch.Tensor],
4230
  end_token_id: int,
4231
  max_steps: int,
4232
+ attention_bias: Optional[torch.Tensor] = None,
4233
  ) -> torch.Tensor:
4234
  generated_tokens: List[torch.Tensor] = []
4235
  current_output = initial_output
 
4242
  if bool((next_token == int(end_token_id)).all()):
4243
  hit_end = True
4244
  break
4245
+ if attention_bias is None:
4246
+ current_output, current_attention_mask = self._consume_generation_tokens(
4247
+ next_token,
4248
+ past_key_values=current_past_key_values,
4249
+ attention_mask=current_attention_mask,
4250
+ )
4251
+ current_past_key_values = current_output.past_key_values
4252
+ else:
4253
+ last_hidden, current_past_key_values = self._run_ar_decode_step(
4254
+ next_token,
4255
+ past_key_values=current_past_key_values,
4256
+ attention_bias=attention_bias,
4257
+ )
4258
+ current_output = MolmoAct2CausalLMOutputWithPast(
4259
+ logits=self.lm_head(last_hidden),
4260
+ past_key_values=current_past_key_values,
4261
+ )
4262
  if not generated_tokens:
4263
  raise RuntimeError("Discrete continuation generated no tokens.")
4264
  if not hit_end:
 
4736
  .astype(np.int64),
4737
  }
4738
  else:
4739
+ max_action_decode_steps = max(1, int(self.config.action_horizon * 16))
4740
+ action_attention_bias = None
4741
+ if enable_cuda_graph:
4742
+ action_static_cache = self._make_ar_decode_static_cache(
4743
+ inputs,
4744
+ max_steps=max_action_decode_steps,
4745
+ )
4746
+ action_attention_bias = self._make_depth_decode_attention_bias(
4747
+ inputs,
4748
+ action_static_cache,
4749
+ )
4750
+ prefill_output = self(
4751
+ **inputs,
4752
+ use_cache=True,
4753
+ past_key_values=action_static_cache,
4754
+ )
4755
+ else:
4756
+ prefill_output = self(**inputs, use_cache=True)
4757
  action_token_ids = self._continue_discrete_generation_from_output(
4758
  prefill_output,
4759
  past_key_values=prefill_output.past_key_values,
4760
  attention_mask=inputs.get("attention_mask"),
4761
  end_token_id=self._require_eos_token_id(),
4762
+ max_steps=max_action_decode_steps,
4763
+ attention_bias=action_attention_bias,
4764
  )
4765
  generated_token_ids = action_token_ids
4766
  actions = self._decode_discrete_action_chunk(