Enable CUDA graph AR decode for discrete actions
Generalizes the existing single-token decode CUDA graph helper and uses it for no-depth discrete action autoregressive decoding when enable_cuda_graph=True. The original KV-cache path is preserved when enable_cuda_graph=False.
- modeling_molmoact2.py +59 -10
modeling_molmoact2.py
CHANGED
|
@@ -4127,14 +4127,14 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
|
| 4127 |
|
| 4128 |
def _embed_base_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 4129 |
# Skips MolmoAct2Embedding's per-call cat([base, new]); safe only for IDs
|
| 4130 |
-
# below text_config.vocab_size
|
| 4131 |
wte = self.model.transformer.wte
|
| 4132 |
base_embedding = getattr(wte, "embedding", None)
|
| 4133 |
if base_embedding is None:
|
| 4134 |
return wte(input_ids)
|
| 4135 |
return F.embedding(input_ids, base_embedding)
|
| 4136 |
|
| 4137 |
-
def _run_depth_decode_step(
|
| 4138 |
self,
|
| 4139 |
token_ids: torch.Tensor,
|
| 4140 |
*,
|
|
@@ -4178,6 +4178,19 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
|
| 4178 |
)
|
| 4179 |
return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values
|
| 4180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4181 |
def _project_depth_logits(self, last_hidden: torch.Tensor) -> torch.Tensor:
|
| 4182 |
start = int(self.config.depth_token_start_id)
|
| 4183 |
end_id = start + int(self.config.num_depth_tokens)
|
|
@@ -4190,6 +4203,12 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
|
| 4190 |
1,
|
| 4191 |
)
|
| 4192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4193 |
def _make_depth_static_cache(self, inputs: Mapping[str, Any]) -> Cache:
|
| 4194 |
prompt_len = inputs["input_ids"].shape[1]
|
| 4195 |
action_horizon = int(self.config.action_horizon or 1)
|
|
@@ -4210,6 +4229,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
|
| 4210 |
attention_mask: Optional[torch.Tensor],
|
| 4211 |
end_token_id: int,
|
| 4212 |
max_steps: int,
|
|
|
|
| 4213 |
) -> torch.Tensor:
|
| 4214 |
generated_tokens: List[torch.Tensor] = []
|
| 4215 |
current_output = initial_output
|
|
@@ -4222,12 +4242,23 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
|
| 4222 |
if bool((next_token == int(end_token_id)).all()):
|
| 4223 |
hit_end = True
|
| 4224 |
break
|
| 4225 |
-
|
| 4226 |
-
|
| 4227 |
-
|
| 4228 |
-
|
| 4229 |
-
|
| 4230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4231 |
if not generated_tokens:
|
| 4232 |
raise RuntimeError("Discrete continuation generated no tokens.")
|
| 4233 |
if not hit_end:
|
|
@@ -4705,13 +4736,31 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
|
|
| 4705 |
.astype(np.int64),
|
| 4706 |
}
|
| 4707 |
else:
|
| 4708 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4709 |
action_token_ids = self._continue_discrete_generation_from_output(
|
| 4710 |
prefill_output,
|
| 4711 |
past_key_values=prefill_output.past_key_values,
|
| 4712 |
attention_mask=inputs.get("attention_mask"),
|
| 4713 |
end_token_id=self._require_eos_token_id(),
|
| 4714 |
-
max_steps=max(1, int(self.config.action_horizon * 16)),
|
|
|
|
| 4715 |
)
|
| 4716 |
generated_token_ids = action_token_ids
|
| 4717 |
actions = self._decode_discrete_action_chunk(
|
|
|
|
| 4127 |
|
| 4128 |
def _embed_base_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 4129 |
# Skips MolmoAct2Embedding's per-call cat([base, new]); safe only for IDs
|
| 4130 |
+
# below text_config.vocab_size. This includes released depth/action tokens.
|
| 4131 |
wte = self.model.transformer.wte
|
| 4132 |
base_embedding = getattr(wte, "embedding", None)
|
| 4133 |
if base_embedding is None:
|
| 4134 |
return wte(input_ids)
|
| 4135 |
return F.embedding(input_ids, base_embedding)
|
| 4136 |
|
| 4137 |
+
def _run_ar_decode_step(
|
| 4138 |
self,
|
| 4139 |
token_ids: torch.Tensor,
|
| 4140 |
*,
|
|
|
|
| 4178 |
)
|
| 4179 |
return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values
|
| 4180 |
|
| 4181 |
+
def _run_depth_decode_step(
|
| 4182 |
+
self,
|
| 4183 |
+
token_ids: torch.Tensor,
|
| 4184 |
+
*,
|
| 4185 |
+
past_key_values: Cache,
|
| 4186 |
+
attention_bias: torch.Tensor,
|
| 4187 |
+
) -> Tuple[torch.Tensor, Cache]:
|
| 4188 |
+
return self._run_ar_decode_step(
|
| 4189 |
+
token_ids,
|
| 4190 |
+
past_key_values=past_key_values,
|
| 4191 |
+
attention_bias=attention_bias,
|
| 4192 |
+
)
|
| 4193 |
+
|
| 4194 |
def _project_depth_logits(self, last_hidden: torch.Tensor) -> torch.Tensor:
|
| 4195 |
start = int(self.config.depth_token_start_id)
|
| 4196 |
end_id = start + int(self.config.num_depth_tokens)
|
|
|
|
| 4203 |
1,
|
| 4204 |
)
|
| 4205 |
|
| 4206 |
+
def _make_ar_decode_static_cache(self, inputs: Mapping[str, Any], max_steps: int) -> Cache:
|
| 4207 |
+
prompt_len = inputs["input_ids"].shape[1]
|
| 4208 |
+
return self.depth_decode_cuda_graph_manager.make_static_cache(
|
| 4209 |
+
max_cache_len=prompt_len + max(1, int(max_steps)),
|
| 4210 |
+
)
|
| 4211 |
+
|
| 4212 |
def _make_depth_static_cache(self, inputs: Mapping[str, Any]) -> Cache:
|
| 4213 |
prompt_len = inputs["input_ids"].shape[1]
|
| 4214 |
action_horizon = int(self.config.action_horizon or 1)
|
|
|
|
| 4229 |
attention_mask: Optional[torch.Tensor],
|
| 4230 |
end_token_id: int,
|
| 4231 |
max_steps: int,
|
| 4232 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 4233 |
) -> torch.Tensor:
|
| 4234 |
generated_tokens: List[torch.Tensor] = []
|
| 4235 |
current_output = initial_output
|
|
|
|
| 4242 |
if bool((next_token == int(end_token_id)).all()):
|
| 4243 |
hit_end = True
|
| 4244 |
break
|
| 4245 |
+
if attention_bias is None:
|
| 4246 |
+
current_output, current_attention_mask = self._consume_generation_tokens(
|
| 4247 |
+
next_token,
|
| 4248 |
+
past_key_values=current_past_key_values,
|
| 4249 |
+
attention_mask=current_attention_mask,
|
| 4250 |
+
)
|
| 4251 |
+
current_past_key_values = current_output.past_key_values
|
| 4252 |
+
else:
|
| 4253 |
+
last_hidden, current_past_key_values = self._run_ar_decode_step(
|
| 4254 |
+
next_token,
|
| 4255 |
+
past_key_values=current_past_key_values,
|
| 4256 |
+
attention_bias=attention_bias,
|
| 4257 |
+
)
|
| 4258 |
+
current_output = MolmoAct2CausalLMOutputWithPast(
|
| 4259 |
+
logits=self.lm_head(last_hidden),
|
| 4260 |
+
past_key_values=current_past_key_values,
|
| 4261 |
+
)
|
| 4262 |
if not generated_tokens:
|
| 4263 |
raise RuntimeError("Discrete continuation generated no tokens.")
|
| 4264 |
if not hit_end:
|
|
|
|
| 4736 |
.astype(np.int64),
|
| 4737 |
}
|
| 4738 |
else:
|
| 4739 |
+
max_action_decode_steps = max(1, int(self.config.action_horizon * 16))
|
| 4740 |
+
action_attention_bias = None
|
| 4741 |
+
if enable_cuda_graph:
|
| 4742 |
+
action_static_cache = self._make_ar_decode_static_cache(
|
| 4743 |
+
inputs,
|
| 4744 |
+
max_steps=max_action_decode_steps,
|
| 4745 |
+
)
|
| 4746 |
+
action_attention_bias = self._make_depth_decode_attention_bias(
|
| 4747 |
+
inputs,
|
| 4748 |
+
action_static_cache,
|
| 4749 |
+
)
|
| 4750 |
+
prefill_output = self(
|
| 4751 |
+
**inputs,
|
| 4752 |
+
use_cache=True,
|
| 4753 |
+
past_key_values=action_static_cache,
|
| 4754 |
+
)
|
| 4755 |
+
else:
|
| 4756 |
+
prefill_output = self(**inputs, use_cache=True)
|
| 4757 |
action_token_ids = self._continue_discrete_generation_from_output(
|
| 4758 |
prefill_output,
|
| 4759 |
past_key_values=prefill_output.past_key_values,
|
| 4760 |
attention_mask=inputs.get("attention_mask"),
|
| 4761 |
end_token_id=self._require_eos_token_id(),
|
| 4762 |
+
max_steps=max_action_decode_steps,
|
| 4763 |
+
attention_bias=action_attention_bias,
|
| 4764 |
)
|
| 4765 |
generated_token_ids = action_token_ids
|
| 4766 |
actions = self._decode_discrete_action_chunk(
|