hqfang committed on
Commit
e1d2401
·
verified ·
1 Parent(s): 48551e6

Update MolmoAct2 action mode inference API

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. configuration_molmoact2.py +2 -2
  3. modeling_molmoact2.py +22 -17
config.json CHANGED
@@ -4,7 +4,7 @@
4
  "action_expert_depth_gate": false,
5
  "action_expert_depth_gate_init_bias": -4.0,
6
  "action_expert_depth_gate_per_layer": false,
7
- "action_format": "discrete",
8
  "max_action_horizon": 30,
9
  "action_output_token_id": 151931,
10
  "action_start_token_id": 151932,
 
4
  "action_expert_depth_gate": false,
5
  "action_expert_depth_gate_init_bias": -4.0,
6
  "action_expert_depth_gate_per_layer": false,
7
+ "action_mode": "discrete",
8
  "max_action_horizon": 30,
9
  "action_output_token_id": 151931,
10
  "action_start_token_id": 151932,
configuration_molmoact2.py CHANGED
@@ -375,7 +375,7 @@ class MolmoAct2Config(PretrainedConfig):
375
  max_action_dim: int = 32,
376
  max_action_horizon: int = 30,
377
  n_obs_steps: int = 30,
378
- action_format: str = "both",
379
  state_format: str = "discrete",
380
  flow_matching_num_steps: int = 10,
381
  flow_matching_cutoff: float = 1.0,
@@ -461,7 +461,7 @@ class MolmoAct2Config(PretrainedConfig):
461
  self.max_action_dim = max_action_dim
462
  self.max_action_horizon = max_action_horizon
463
  self.n_obs_steps = n_obs_steps
464
- self.action_format = action_format
465
  self.state_format = state_format
466
  self.flow_matching_num_steps = flow_matching_num_steps
467
  self.flow_matching_cutoff = flow_matching_cutoff
 
375
  max_action_dim: int = 32,
376
  max_action_horizon: int = 30,
377
  n_obs_steps: int = 30,
378
+ action_mode: str = "both",
379
  state_format: str = "discrete",
380
  flow_matching_num_steps: int = 10,
381
  flow_matching_cutoff: float = 1.0,
 
461
  self.max_action_dim = max_action_dim
462
  self.max_action_horizon = max_action_horizon
463
  self.n_obs_steps = n_obs_steps
464
+ self.action_mode = action_mode
465
  self.state_format = state_format
466
  self.flow_matching_num_steps = flow_matching_num_steps
467
  self.flow_matching_cutoff = flow_matching_cutoff
modeling_molmoact2.py CHANGED
@@ -2949,7 +2949,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel):
2949
  mask = input_ids != -1
2950
  else:
2951
  return None
2952
- if self.config.action_format != "both" or input_ids is None:
2953
  return mask
2954
  eos_id = getattr(self.config, "eos_token_id", None)
2955
  if eos_id is not None:
@@ -4452,7 +4452,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4452
  ) -> torch.Tensor:
4453
  if action_tokenizer is None:
4454
  raise ValueError(
4455
- "action_mode='discrete' requires an `action_tokenizer` input."
4456
  )
4457
  if (
4458
  self.config.action_start_token_id is None
@@ -4508,7 +4508,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4508
  task: str,
4509
  state: Any,
4510
  norm_tag: str,
4511
- action_mode: str = "continuous",
4512
  enable_depth_reasoning: bool = False,
4513
  enable_adaptive_depth: bool = True,
4514
  depth_cache: Optional[Mapping[str, Any]] = None,
@@ -4524,31 +4524,36 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4524
  raise ValueError(
4525
  "MolmoAct2 `predict_action` requires `state` for discrete state prompting."
4526
  )
4527
- action_mode = str(action_mode or "continuous")
4528
- if action_mode not in {"continuous", "discrete"}:
4529
- raise ValueError("action_mode must be either 'continuous' or 'discrete'.")
4530
- if action_mode == "continuous" and not bool(self.config.add_action_expert):
 
 
 
 
 
4531
  raise RuntimeError(
4532
- "action_mode='continuous' requires an action expert, but this checkpoint "
4533
  "was converted with add_action_expert=False."
4534
  )
4535
- if action_mode == "continuous" and self.config.action_format not in {
4536
  "continuous",
4537
  "both",
4538
  }:
4539
  raise ValueError(
4540
- f"action_mode='continuous' requires checkpoint action_format in {{'continuous', 'both'}}, "
4541
- f"got {self.config.action_format!r}."
4542
  )
4543
- if action_mode == "discrete":
4544
  if action_tokenizer is None:
4545
  raise ValueError(
4546
- "action_mode='discrete' requires an `action_tokenizer` input."
4547
  )
4548
- if self.config.action_format not in {"discrete", "both"}:
4549
  raise ValueError(
4550
- f"action_mode='discrete' requires checkpoint action_format in {{'discrete', 'both'}}, "
4551
- f"got {self.config.action_format!r}."
4552
  )
4553
  if enable_depth_reasoning and not bool(self.config.enable_depth_reasoning):
4554
  raise ValueError(
@@ -4625,7 +4630,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
4625
  generated_token_ids = None
4626
  depth_bins = None
4627
  updated_depth_cache = depth_cache
4628
- if action_mode == "continuous":
4629
  if enable_depth_reasoning:
4630
  latest_first_image = _extract_first_image(images)
4631
  depth_prefix = self._generate_depth_prefix(
 
2949
  mask = input_ids != -1
2950
  else:
2951
  return None
2952
+ if self.config.action_mode != "both" or input_ids is None:
2953
  return mask
2954
  eos_id = getattr(self.config, "eos_token_id", None)
2955
  if eos_id is not None:
 
4452
  ) -> torch.Tensor:
4453
  if action_tokenizer is None:
4454
  raise ValueError(
4455
+ "inference_action_mode='discrete' requires an `action_tokenizer` input."
4456
  )
4457
  if (
4458
  self.config.action_start_token_id is None
 
4508
  task: str,
4509
  state: Any,
4510
  norm_tag: str,
4511
+ inference_action_mode: Optional[str] = None,
4512
  enable_depth_reasoning: bool = False,
4513
  enable_adaptive_depth: bool = True,
4514
  depth_cache: Optional[Mapping[str, Any]] = None,
 
4524
  raise ValueError(
4525
  "MolmoAct2 `predict_action` requires `state` for discrete state prompting."
4526
  )
4527
+ if inference_action_mode is None:
4528
+ raise ValueError(
4529
+ "`inference_action_mode` must be provided explicitly as either "
4530
+ "'continuous' or 'discrete'."
4531
+ )
4532
+ inference_action_mode = str(inference_action_mode)
4533
+ if inference_action_mode not in {"continuous", "discrete"}:
4534
+ raise ValueError("inference_action_mode must be either 'continuous' or 'discrete'.")
4535
+ if inference_action_mode == "continuous" and not bool(self.config.add_action_expert):
4536
  raise RuntimeError(
4537
+ "inference_action_mode='continuous' requires an action expert, but this checkpoint "
4538
  "was converted with add_action_expert=False."
4539
  )
4540
+ if inference_action_mode == "continuous" and self.config.action_mode not in {
4541
  "continuous",
4542
  "both",
4543
  }:
4544
  raise ValueError(
4545
+ "inference_action_mode='continuous' requires checkpoint action_mode in "
4546
+ f"{{'continuous', 'both'}}, got {self.config.action_mode!r}."
4547
  )
4548
+ if inference_action_mode == "discrete":
4549
  if action_tokenizer is None:
4550
  raise ValueError(
4551
+ "inference_action_mode='discrete' requires an `action_tokenizer` input."
4552
  )
4553
+ if self.config.action_mode not in {"discrete", "both"}:
4554
  raise ValueError(
4555
+ "inference_action_mode='discrete' requires checkpoint action_mode in "
4556
+ f"{{'discrete', 'both'}}, got {self.config.action_mode!r}."
4557
  )
4558
  if enable_depth_reasoning and not bool(self.config.enable_depth_reasoning):
4559
  raise ValueError(
 
4630
  generated_token_ids = None
4631
  depth_bins = None
4632
  updated_depth_cache = depth_cache
4633
+ if inference_action_mode == "continuous":
4634
  if enable_depth_reasoning:
4635
  latest_first_image = _extract_first_image(images)
4636
  depth_prefix = self._generate_depth_prefix(