hqfang commited on
Commit
048fc26
·
verified ·
1 Parent(s): 2478201

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,31 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - molmoact2
5
+ - robotics
6
+ - image-text-to-text
7
+ - depth-reasoning
8
+ ---
9
+
10
+ <img src="assets/MolmoAct2-Think.svg" alt="MolmoAct Think Logo" style="width: auto; height: 50px;">
11
+
12
+ # **MolmoAct2-Think**
13
+
14
+ MolmoAct2-Think extends MolmoAct2 with depth-token reasoning. Before producing an action, the model can predict a compact 10 x 10 discrete depth representation and condition the action expert on the resulting depth-aware VLM cache.
15
+
16
+ This checkpoint is the post-trained, multi-embodiment depth-reasoning model. It is intended as a foundation checkpoint for further robot fine-tuning rather than as a ready-to-run policy for a single deployment setting.
17
+
18
+ ## Quick Links
19
+
20
+ - 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
21
+ - 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
22
+ - 📄 Paper: coming soon
23
+ - 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
24
+ - 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)
25
+
26
+ ## Intended Use
27
+
28
+ Use this checkpoint for further fine-tuning when the downstream policy should use depth reasoning. It contains the VLM, action expert, and depth-token weights, plus normalization metadata for the post-training mixture in `norm_stats.json`.
29
+
30
+ This model card intentionally does not include direct policy inference code. For ready-to-run depth-reasoning inference, use the fine-tuned `MolmoAct2-Think-LIBERO` checkpoint.
31
+
assets/MolmoAct2-Think.svg ADDED
chat_template.jinja ADDED
@@ -0,0 +1 @@
 
 
1
+ {% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% 
continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}{% if not (has_subtitle and loop.index == 2) and not (not has_subtitle and loop.first) %}{{ '<|im_end|>\n' }}{% endif %}{{ '<|im_start|>user\n' }}{{ text_content }}{{ '<|im_end|>\n' }}{% else %} {# assistant #}{{ '<|im_start|>assistant\n' }}{{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
config.json ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_end_token_id": 151933,
3
+ "action_expert_condition_source": "kv_cache",
4
+ "action_expert_config": {
5
+ "attn_dropout": 0.0,
6
+ "causal_attn": false,
7
+ "compile": "blocks",
8
+ "context_layer_norm": true,
9
+ "dropout": 0.0,
10
+ "ffn_multiple_of": 256,
11
+ "hidden_size": 768,
12
+ "implementation": "new",
13
+ "max_action_dim": 32,
14
+ "max_horizon": 32,
15
+ "mlp_ratio": 4.0,
16
+ "model_type": "molmoact2_action_expert",
17
+ "num_heads": 8,
18
+ "num_layers": 36,
19
+ "qk_norm": true,
20
+ "qk_norm_eps": 1e-06,
21
+ "rope": true,
22
+ "rope_on_cross_attention": true,
23
+ "timestep_embed_dim": 256
24
+ },
25
+ "action_expert_depth_gate": false,
26
+ "action_expert_depth_gate_init_bias": -4.0,
27
+ "action_expert_depth_gate_per_layer": false,
28
+ "action_expert_layer_mode": "per_layer",
29
+ "action_format": "both",
30
+ "action_horizon": 30,
31
+ "action_output_token_id": 151931,
32
+ "action_start_token_id": 151932,
33
+ "action_token_start_id": 151934,
34
+ "adapter_config": {
35
+ "attention_dropout": 0.0,
36
+ "attn_implementation": "sdpa",
37
+ "float32_attention": true,
38
+ "head_dim": 72,
39
+ "hidden_act": "silu",
40
+ "hidden_size": 1152,
41
+ "image_feature_dropout": 0.0,
42
+ "initializer_range": 0.02,
43
+ "intermediate_size": 9728,
44
+ "model_type": "molmoact2",
45
+ "num_attention_heads": 16,
46
+ "num_key_value_heads": 16,
47
+ "pooling_attention_mask": true,
48
+ "residual_dropout": 0.0,
49
+ "text_hidden_size": 2560,
50
+ "vit_layers": [
51
+ -3,
52
+ -9
53
+ ]
54
+ },
55
+ "add_action_expert": true,
56
+ "add_control_tokens": true,
57
+ "add_setup_tokens": true,
58
+ "architectures": [
59
+ "MolmoAct2ForConditionalGeneration"
60
+ ],
61
+ "auto_map": {
62
+ "AutoConfig": "configuration_molmoact2.MolmoAct2Config",
63
+ "AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
64
+ },
65
+ "depth_end_token_id": 153984,
66
+ "depth_mode": 2,
67
+ "depth_output_token_id": 153982,
68
+ "depth_start_token_id": 153983,
69
+ "depth_token_start_id": 153985,
70
+ "dtype": "float32",
71
+ "enable_depth_reasoning": true,
72
+ "flow_matching_beta_alpha": 1.0,
73
+ "flow_matching_beta_beta": 1.5,
74
+ "flow_matching_cutoff": 1.0,
75
+ "flow_matching_num_steps": 10,
76
+ "flow_matching_time_offset": 0.001,
77
+ "flow_matching_time_scale": 0.999,
78
+ "frame_end_token_id": 155656,
79
+ "frame_start_token_id": 155655,
80
+ "image_col_id": 155651,
81
+ "image_end_token_id": 155649,
82
+ "image_high_res_id": 155650,
83
+ "image_low_res_id": 155654,
84
+ "image_patch_id": 155650,
85
+ "image_start_token_id": 155648,
86
+ "initializer_range": 0.02,
87
+ "low_res_image_start_token_id": 155652,
88
+ "mask_action_dim_padding": true,
89
+ "max_action_dim": 32,
90
+ "model_type": "molmoact2",
91
+ "n_obs_steps": 1,
92
+ "norm_stats_filename": "norm_stats.json",
93
+ "num_action_tokens": 2048,
94
+ "num_depth_codes": 100,
95
+ "num_depth_tokens": 128,
96
+ "num_state_tokens": 256,
97
+ "state_end_token_id": 151674,
98
+ "state_format": "discrete",
99
+ "state_start_token_id": 151673,
100
+ "state_token_start_id": 151675,
101
+ "text_config": {
102
+ "additional_vocab_size": 128,
103
+ "attention_dropout": 0.0,
104
+ "attn_implementation": "sdpa",
105
+ "embedding_dropout": 0.0,
106
+ "head_dim": 128,
107
+ "hidden_act": "silu",
108
+ "hidden_size": 2560,
109
+ "initializer_range": 0.02,
110
+ "intermediate_size": 9728,
111
+ "layer_norm_eps": 1e-06,
112
+ "max_position_embeddings": 16384,
113
+ "model_type": "molmoact2_text",
114
+ "norm_after": false,
115
+ "num_attention_heads": 32,
116
+ "num_hidden_layers": 36,
117
+ "num_key_value_heads": 8,
118
+ "qk_norm_type": "qwen3",
119
+ "qkv_bias": false,
120
+ "residual_dropout": 0.0,
121
+ "rope_parameters": {
122
+ "rope_theta": 5000000.0,
123
+ "rope_type": "default"
124
+ },
125
+ "rope_scaling_layers": null,
126
+ "rope_theta": 5000000.0,
127
+ "tie_word_embeddings": false,
128
+ "use_cache": true,
129
+ "use_qk_norm": true,
130
+ "vocab_size": 155648
131
+ },
132
+ "tie_word_embeddings": false,
133
+ "transformers_version": "5.3.0",
134
+ "use_frame_special_tokens": true,
135
+ "vit_config": {
136
+ "attention_dropout": 0.0,
137
+ "attn_implementation": "sdpa",
138
+ "float32_attention": true,
139
+ "head_dim": 72,
140
+ "hidden_act": "gelu_pytorch_tanh",
141
+ "hidden_size": 1152,
142
+ "image_default_input_size": [
143
+ 378,
144
+ 378
145
+ ],
146
+ "image_num_pos": 729,
147
+ "image_patch_size": 14,
148
+ "initializer_range": 0.02,
149
+ "intermediate_size": 4304,
150
+ "layer_norm_eps": 1e-06,
151
+ "model_type": "molmoact2",
152
+ "num_attention_heads": 16,
153
+ "num_hidden_layers": 27,
154
+ "num_key_value_heads": 16,
155
+ "residual_dropout": 0.0
156
+ },
157
+ "bos_token_id": 151645,
158
+ "eos_token_id": 151645,
159
+ "pad_token_id": 151643
160
+ }
configuration_molmoact2.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MolmoAct2 configuration
3
+ """
4
+
5
+ from typing import Optional, Any
6
+
7
+ from transformers import PretrainedConfig
8
+ from transformers.modeling_rope_utils import rope_config_validation
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
class MolmoAct2VitConfig(PretrainedConfig):
    r"""
    Configuration class for a [`MolmoAct2VisionTransformer`].

    Holds the hyperparameters that define the vision backbone: transformer
    geometry (layers, heads, hidden sizes), the image patching scheme, dropout
    rates, and attention options. Instantiating with no arguments yields the
    default architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Example:
    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer

    >>> # Initializing a MolmoAct2VitConfig
    >>> configuration = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2VisionTransformer (with random weights)
    >>> model = MolmoAct2VisionTransformer(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    base_config_key = "vit_config"

    def __init__(
        self,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        num_hidden_layers: int = 27,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        hidden_act: str = "gelu_pytorch_tanh",
        layer_norm_eps: float = 1e-6,
        image_default_input_size: tuple[int, int] = (378, 378),
        image_patch_size: int = 14,
        image_num_pos: int = 577,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        initializer_range: float = 0.02,
        float32_attention: bool = True,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        # Record the attention backend on the instance and also forward it to
        # the base class so PretrainedConfig tracks it as well.
        self.attn_implementation = attn_implementation
        super().__init__(attn_implementation=attn_implementation, **kwargs)

        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

        # Image patching scheme.
        self.image_default_input_size = image_default_input_size
        self.image_patch_size = image_patch_size
        self.image_num_pos = image_num_pos

        # Regularization, init, and numerics.
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.initializer_range = initializer_range
        self.float32_attention = float32_attention

    @property
    def image_num_patch(self):
        """Patch-grid size (rows, cols) implied by the default input size."""
        height, width = self.image_default_input_size
        patch = self.image_patch_size
        return height // patch, width // patch
85
+
86
+
87
class MolmoAct2AdapterConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
    It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
    defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone

    >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
    >>> vit_config = MolmoAct2VitConfig()
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
    >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)

    >>> # Accessing the model configuration
    >>> vit_configuration = model.vit_config
    >>> adapter_configuration = model.adapter_config
    ```"""

    model_type = "molmoact2"
    base_config_key = "adapter_config"

    def __init__(
        self,
        vit_layers: tuple = (-3, -9),
        pooling_attention_mask: bool = False,
        hidden_size: int = 1152,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        float32_attention: bool = True,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        hidden_act: str = "silu",
        intermediate_size: int = 18944,
        text_hidden_size: int = 3584,
        image_feature_dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        """Build the adapter config.

        Args:
            vit_layers: Indices (negative = from the end) of the ViT layers whose
                features the adapter consumes.
            pooling_attention_mask: Whether the pooling attention uses a mask.
            text_hidden_size: Hidden size of the language model the adapter
                projects image features into.
            attn_implementation: Attention backend name, forwarded to
                `PretrainedConfig` and also stored on the instance.
        """
        self.attn_implementation = attn_implementation
        super().__init__(
            attn_implementation=attn_implementation,
            **kwargs
        )
        self.vit_layers = vit_layers
        self.pooling_attention_mask = pooling_attention_mask
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.float32_attention = float32_attention
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.text_hidden_size = text_hidden_size
        self.image_feature_dropout = image_feature_dropout
        self.initializer_range = initializer_range
154
+
155
+
156
class MolmoAct2TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
    `MolmoAct2TextModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:
    ```python
    >>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel

    >>> # Initializing a MolmoAct2TextConfig
    >>> configuration = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2TextModel (with random weights)
    >>> model = MolmoAct2TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2_text"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel plan: column-wise split the fused QKV and FFN up-projections,
    # row-wise split the corresponding output projections.
    base_model_tp_plan = {
        "blocks.*.self_attn.att_proj": "colwise",
        "blocks.*.self_attn.attn_out": "rowwise",
        "blocks.*.mlp.ff_proj": "colwise",
        "blocks.*.mlp.ff_out": "rowwise",
    }
    # Pipeline-parallel plan: embedding -> transformer blocks -> final layer norm.
    base_model_pp_plan = {
        "wte": (["input_ids"], ["inputs_embeds"]),
        "blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "ln_f": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        hidden_size: int = 3584,
        num_attention_heads: int = 28,
        num_key_value_heads: Optional[int] = 4,
        head_dim: int = 128,
        vocab_size: int = 152064,
        additional_vocab_size: int = 128,
        qkv_bias: bool = True,
        num_hidden_layers: int = 48,
        intermediate_size: int = 18944,
        hidden_act: str = "silu",
        embedding_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        max_position_embeddings: int = 4096,
        rope_theta: float = 1000000.0,
        rope_scaling: Optional[dict[str, Any]] = None,
        rope_scaling_layers: Optional[list[int]] = None,
        use_qk_norm: bool = False,
        qk_norm_type: str = "olmo",
        layer_norm_eps: float = 1e-6,
        norm_after: bool = False,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        tie_word_embeddings: bool = False,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        """Build the text-model config.

        Args:
            num_key_value_heads: Number of KV heads for grouped-query attention;
                `None` falls back to `num_attention_heads` (i.e. full MHA).
            additional_vocab_size: Extra embedding rows appended beyond
                `vocab_size` (e.g. for special/control tokens).
            rope_scaling: Optional RoPE scaling dict, validated by
                `rope_config_validation`.
            rope_scaling_layers: Optional subset of layer indices the RoPE
                scaling applies to.
        """
        self.attn_implementation = attn_implementation
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            attn_implementation=attn_implementation,
            **kwargs
        )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # GQA: default the KV-head count to full multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.vocab_size = vocab_size
        self.additional_vocab_size = additional_vocab_size
        self.qkv_bias = qkv_bias
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.embedding_dropout = embedding_dropout
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.rope_scaling_layers = rope_scaling_layers
        self.use_qk_norm = use_qk_norm
        self.qk_norm_type = qk_norm_type
        self.layer_norm_eps = layer_norm_eps
        self.norm_after = norm_after
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Validate the correctness of rotary position embeddings parameters
        rope_config_validation(self)
256
+
257
+
258
class MolmoAct2ActionExpertConfig(PretrainedConfig):
    r"""Configuration for the MolmoAct2 modern action expert.

    Describes the flow-matching action head: its transformer geometry, the
    maximum action horizon/dimension it supports, and attention options
    (QK-norm, RoPE, causal masking). Only the ``"new"`` implementation is
    accepted by the HF export.
    """

    model_type = "molmoact2_action_expert"
    base_config_key = "action_expert_config"

    def __init__(
        self,
        max_horizon: int = 32,
        max_action_dim: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 32,
        num_heads: int = 16,
        mlp_ratio: float = 8.0 / 3.0,
        ffn_multiple_of: int = 256,
        timestep_embed_dim: int = 256,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        context_layer_norm: bool = True,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        rope: bool = True,
        rope_on_cross_attention: bool = False,
        causal_attn: bool = False,
        compile: str = "blocks",
        implementation: str = "new",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Guard clause: the HF export only ships the rewritten expert.
        if implementation != "new":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert.implementation='new'."
            )

        # Action-chunk capacity.
        self.max_horizon = max_horizon
        self.max_action_dim = max_action_dim

        # Transformer geometry.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.ffn_multiple_of = ffn_multiple_of
        self.timestep_embed_dim = timestep_embed_dim

        # Regularization and attention options.
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.context_layer_norm = context_layer_norm
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps
        self.rope = rope
        self.rope_on_cross_attention = rope_on_cross_attention
        self.causal_attn = causal_attn

        # Compilation / implementation selection.
        self.compile = compile
        self.implementation = implementation
309
+
310
+
311
class MolmoAct2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
    It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.

    Example:

    ```python
    >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig

    >>> # Initializing a MolmoAct2VitConfig
    >>> vit_config = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2AdapterConfig
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2TextConfig
    >>> text_config = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2Config
    >>> configuration = MolmoAct2Config(
    >>>     vit_config=vit_config,
    >>>     adapter_config=adapter_config,
    >>>     text_config=text_config,
    >>>     image_start_token_id=151936,
    >>>     image_end_token_id=151937,
    >>>     image_patch_id=151938,
    >>>     image_col_id=151939,
    >>>     low_res_image_start_token_id=151940,
    >>>     image_low_res_id=151942,
    >>>     frame_start_token_id=151943,
    >>>     frame_end_token_id=151944,
    >>> )

    >>> # Initializing a model
    >>> model = MolmoAct2ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    # Sub-config classes used to (de)serialize the nested configs below.
    sub_configs = {
        "text_config": MolmoAct2TextConfig,
        "vit_config": MolmoAct2VitConfig,
        "adapter_config": MolmoAct2AdapterConfig,
        "action_expert_config": MolmoAct2ActionExpertConfig,
    }

    def __init__(
        self,
        vit_config: MolmoAct2VitConfig = None,
        adapter_config: MolmoAct2AdapterConfig = None,
        text_config: MolmoAct2TextConfig = None,
        action_expert_config: MolmoAct2ActionExpertConfig = None,
        image_start_token_id: int = None,
        low_res_image_start_token_id: int = None,
        image_end_token_id: int = None,
        image_low_res_id: int = None,
        image_patch_id: int = None,
        image_col_id: int = None,
        frame_start_token_id: int = None,
        frame_end_token_id: int = None,
        use_frame_special_tokens: bool = True,
        initializer_range: float = 0.02,
        add_action_expert: bool = True,
        max_action_dim: int = 7,
        action_horizon: int = 16,
        n_obs_steps: int = 1,
        action_format: str = "continuous",
        state_format: str = "discrete",
        action_expert_condition_source: str = "kv_cache",
        action_expert_layer_mode: str = "per_layer",
        flow_matching_num_steps: int = 10,
        flow_matching_cutoff: float = 1.0,
        flow_matching_time_offset: float = 0.001,
        flow_matching_time_scale: float = 0.999,
        flow_matching_beta_alpha: float = 1.0,
        flow_matching_beta_beta: float = 1.5,
        mask_action_dim_padding: bool = True,
        enable_depth_reasoning: bool = False,
        depth_mode: int = 2,
        num_depth_codes: int = 100,
        action_expert_depth_gate: bool = False,
        action_expert_depth_gate_per_layer: bool = False,
        action_expert_depth_gate_init_bias: float = -4.0,
        action_output_token_id: int = None,
        action_start_token_id: int = None,
        action_end_token_id: int = None,
        action_token_start_id: int = None,
        num_action_tokens: int = 0,
        depth_output_token_id: int = None,
        depth_start_token_id: int = None,
        depth_end_token_id: int = None,
        depth_token_start_id: int = None,
        num_depth_tokens: int = 0,
        state_start_token_id: int = None,
        state_end_token_id: int = None,
        state_token_start_id: int = None,
        num_state_tokens: int = 0,
        add_setup_tokens: bool = True,
        add_control_tokens: bool = True,
        norm_stats_filename: str = "norm_stats.json",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Each sub-config may arrive as None (use defaults), as a dict
        # (deserialized from config.json), or as an already-built config object.
        if vit_config is None:
            self.vit_config = MolmoAct2VitConfig()
        elif isinstance(vit_config, dict):
            self.vit_config = MolmoAct2VitConfig(**vit_config)
        else:
            self.vit_config = vit_config
        if adapter_config is None:
            self.adapter_config = MolmoAct2AdapterConfig()
        elif isinstance(adapter_config, dict):
            self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
        else:
            self.adapter_config = adapter_config
        if text_config is None:
            self.text_config = MolmoAct2TextConfig()
        elif isinstance(text_config, dict):
            self.text_config = MolmoAct2TextConfig(**text_config)
        else:
            self.text_config = text_config
        self.add_action_expert = bool(add_action_expert)
        # Action expert: disabled -> None; otherwise build from defaults (one
        # expert layer per text layer), from a dict, or use the given config.
        if not self.add_action_expert:
            self.action_expert_config = None
        elif action_expert_config is None:
            self.action_expert_config = MolmoAct2ActionExpertConfig(
                max_horizon=action_horizon,
                max_action_dim=max_action_dim,
                num_layers=self.text_config.num_hidden_layers,
            )
        elif isinstance(action_expert_config, dict):
            self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
        else:
            self.action_expert_config = action_expert_config
        # Enforce the combinations supported by the HF export (raises otherwise).
        if self.add_action_expert:
            self._validate_release_action_config(
                action_expert_config=self.action_expert_config,
                action_expert_condition_source=action_expert_condition_source,
                action_expert_layer_mode=action_expert_layer_mode,
                state_format=state_format,
            )
        # Image / frame special-token ids.
        self.image_start_token_id = image_start_token_id
        self.low_res_image_start_token_id = low_res_image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.image_low_res_id = image_low_res_id
        # High-res patch id is the same token as image_patch_id.
        self.image_high_res_id = image_patch_id
        self.image_patch_id = image_patch_id
        self.image_col_id = image_col_id
        self.frame_start_token_id = frame_start_token_id
        self.frame_end_token_id = frame_end_token_id
        self.use_frame_special_tokens = use_frame_special_tokens
        self.initializer_range = initializer_range
        # Action-head shape and formats.
        self.max_action_dim = max_action_dim
        self.action_horizon = action_horizon
        self.n_obs_steps = n_obs_steps
        self.action_format = action_format
        self.state_format = state_format
        self.action_expert_condition_source = action_expert_condition_source
        self.action_expert_layer_mode = action_expert_layer_mode
        # Flow-matching sampler parameters.
        self.flow_matching_num_steps = flow_matching_num_steps
        self.flow_matching_cutoff = flow_matching_cutoff
        self.flow_matching_time_offset = flow_matching_time_offset
        self.flow_matching_time_scale = flow_matching_time_scale
        self.flow_matching_beta_alpha = flow_matching_beta_alpha
        self.flow_matching_beta_beta = flow_matching_beta_beta
        self.mask_action_dim_padding = mask_action_dim_padding
        # Depth-reasoning options.
        self.enable_depth_reasoning = enable_depth_reasoning
        self.depth_mode = depth_mode
        self.num_depth_codes = num_depth_codes
        self.action_expert_depth_gate = action_expert_depth_gate
        self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
        self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
        # Action / depth / state token-id ranges (start id + count each).
        self.action_output_token_id = action_output_token_id
        self.action_start_token_id = action_start_token_id
        self.action_end_token_id = action_end_token_id
        self.action_token_start_id = action_token_start_id
        self.num_action_tokens = num_action_tokens
        self.depth_output_token_id = depth_output_token_id
        self.depth_start_token_id = depth_start_token_id
        self.depth_end_token_id = depth_end_token_id
        self.depth_token_start_id = depth_token_start_id
        self.num_depth_tokens = num_depth_tokens
        self.state_start_token_id = state_start_token_id
        self.state_end_token_id = state_end_token_id
        self.state_token_start_id = state_token_start_id
        self.num_state_tokens = num_state_tokens
        self.add_setup_tokens = add_setup_tokens
        self.add_control_tokens = add_control_tokens
        self.norm_stats_filename = norm_stats_filename

    @staticmethod
    def _validate_release_action_config(
        *,
        action_expert_config: MolmoAct2ActionExpertConfig,
        action_expert_condition_source: str,
        action_expert_layer_mode: str,
        state_format: str,
    ) -> None:
        """Raise ValueError unless the action settings match the HF-export release."""
        if action_expert_config.implementation != "new":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert.implementation='new'."
            )
        if action_expert_condition_source != "kv_cache":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert_condition_source='kv_cache'."
            )
        if action_expert_layer_mode != "per_layer":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert_layer_mode='per_layer'."
            )
        if state_format != "discrete":
            raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")

    # The properties below forward to the sub-configs so callers can read the
    # commonly needed dimensions off the top-level config directly.

    @property
    def image_num_patch(self):
        assert self.vit_config is not None
        return self.vit_config.image_num_patch

    @property
    def num_attention_heads(self):
        return self.text_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.text_config.num_key_value_heads

    @property
    def head_dim(self):
        return self.text_config.head_dim

    @property
    def num_hidden_layers(self):
        return self.text_config.num_hidden_layers

    @property
    def hidden_size(self):
        return self.text_config.hidden_size

    @property
    def vocab_size(self):
        return self.text_config.vocab_size

    @property
    def max_position_embeddings(self):
        return self.text_config.max_position_embeddings
559
+
560
+
561
# Register every config class with the Auto* machinery so they resolve when
# the checkpoint is loaded with `trust_remote_code=True`.
for _config_cls in (
    MolmoAct2VitConfig,
    MolmoAct2AdapterConfig,
    MolmoAct2TextConfig,
    MolmoAct2ActionExpertConfig,
    MolmoAct2Config,
):
    _config_cls.register_for_auto_class()
del _config_cls
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151645,
3
+ "eos_token_id": 151645,
4
+ "pad_token_id": 151643,
5
+ "transformers_version": "5.3.0"
6
+ }
image_processing_molmoact2.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image processor class for MolmoAct2"""
2
+ from typing import Optional, Union
3
+ import numpy as np
4
+ import einops
5
+ import torch
6
+ import torchvision.transforms
7
+
8
+ from transformers.image_utils import (
9
+ IMAGENET_STANDARD_MEAN,
10
+ IMAGENET_STANDARD_STD,
11
+ ImageInput,
12
+ PILImageResampling,
13
+ make_flat_list_of_images,
14
+ valid_images,
15
+ to_numpy_array,
16
+ )
17
+ from transformers.image_transforms import convert_to_rgb
18
+ from transformers.processing_utils import ImagesKwargs
19
+ from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
20
+ from transformers.utils import logging
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.utils import TensorType, logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
def normalize_image(
    image: np.ndarray,
    image_mean: list[float],
    image_std: list[float],
) -> np.ndarray:
    """Channel-wise normalize an HWC float image: (image - mean) / std.

    A fast path handles the common mean = std = [0.5, 0.5, 0.5] case as
    2*x - 1.  Unlike the previous in-place version (`image -=` / `image /=`),
    the caller's array is never mutated.
    """
    if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
        return image * np.float32(2.0) - np.float32(1.0)
    mean = np.asarray(image_mean, dtype=np.float32)[None, None, :]
    std = np.asarray(image_std, dtype=np.float32)[None, None, :]
    # Out-of-place so the input buffer is left untouched.
    return (image - mean) / std
38
+
39
+
40
def resize_image(
    image: np.ndarray,
    desired_output_size: list[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an HWC numpy image to `desired_output_size`, returning float32 in [0, 1].

    Float inputs are assumed to lie in [0, 1]; uint8 inputs in [0, 255].
    Any other integer dtype is rejected.
    """
    chw = torch.from_numpy(image).permute(2, 0, 1)
    input_dtype = chw.dtype
    resize = torchvision.transforms.Resize(desired_output_size, resample, antialias=False)
    if torch.is_floating_point(chw):
        in_min, in_max = 0.0, 1.0
        resized = resize(chw).clamp(0.0, 1.0).to(input_dtype)
    else:
        assert chw.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(chw.dtype)
        in_min, in_max = 0.0, 255.0
        resized = resize(chw).clamp(0, 255).to(input_dtype)

    # Rescale to float32 in [0, 1] regardless of the input value range.
    resized = resized.to(torch.float32)
    resized = (resized - in_min) / (in_max - in_min)

    return resized.permute(1, 2, 0).numpy()
73
+
74
+
75
def select_tiling(h, w, patch_size, max_num_crops):
    """Choose an (n_rows, n_cols) tiling of `patch_size` crops for an (h, w) image.

    At most `max_num_crops` crops in total are used.  Returns an int32 array
    [rows, cols].  Ties prefer smaller tilings.
    """
    # Enumerate every tiling whose crop count fits the budget; sort so that
    # argmin/argmax below favour smaller tilings in the event of a tie.
    tilings = [
        (rows, cols)
        for rows in range(1, max_num_crops + 1)
        for cols in range(1, max_num_crops + 1)
        if rows * cols <= max_num_crops
    ]
    tilings.sort(key=lambda t: (t[0] * t[1], t[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)    # [n_tilings, 2]
    candidate_resolutions = candidate_tilings * patch_size   # [n_tilings, 2]

    original_size = np.array([h, w], dtype=np.float32)       # [2]

    # How much we would need to scale the image to fit exactly in each tiling.
    # The original size can be zero in rare cases if the image is smaller than
    # the margin; letting the scale become infinite means the tiling is based
    # on the other side, or falls back to the smallest tiling.
    # (The previous version had a stray trailing comma here that wrapped the
    # result in a 1-tuple and only worked by accidental numpy broadcasting.)
    with np.errstate(divide='ignore'):
        required_scale_d = candidate_resolutions.astype(np.float32) / original_size
    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_tilings, 1]
    if np.all(required_scale < 1):
        # We are forced to downscale, so try to minimize the amount of downscaling
        ix = np.argmax(required_scale)
    else:
        # Pick the resolution that requires the least upscaling so that it most
        # closely fits the image
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]
106
+
107
+
108
def build_resized_image(
    image: np.ndarray,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Resize and normalize the full image, plus its patch-index grid.

    Returns the [1, h, w, 3] resized image and an [h_patches, w_patches]
    array mapping each patch position to its flat patch index.
    """
    resized = normalize_image(
        resize_image(image, base_image_input_size, resample),
        image_mean,
        image_std,
    )
    if resized.ndim == 3:
        resized = resized[None]
    patch_h = base_image_input_size[0] // image_patch_size
    patch_w = base_image_input_size[1] // image_patch_size
    resize_idx = np.arange(patch_w * patch_h).reshape([patch_h, patch_w])
    return resized, resize_idx
126
+
127
+
128
def build_overlapping_crops(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Decompose an image into a set of overlapping crops

    :return crop_arr: [n_crops, h, w, 3] The crops
    :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
        the crops were extracted from, what patch in `crop_arr` it corresponds to
    """
    # (The previous version computed original_image_h/w and crop_size twice;
    # the duplicates were dead code and have been removed.)
    original_image_h, original_image_w = image.shape[:2]
    crop_size = base_image_input_size[0]
    assert base_image_input_size[0] == base_image_input_size[1]

    left_margin, right_margin = overlap_margins
    total_margin_pixels = image_patch_size * (right_margin + left_margin)  # pixels removed per dim
    crop_patches = base_image_input_size[0] // image_patch_size  # patches per crop dim
    crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
    crop_window_size = crop_window_patches * image_patch_size
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size

    # Decide how to tile the image; to account for the overlap margins we compute
    # the tiling as if we had an image without the margins and were using a crop
    # size without the margins
    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src = resize_image(
        image,
        [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
        resample,
    )
    src = normalize_image(src, image_mean, image_std)

    # Now we have to split the image into crops, and track what patches came from
    # where in `patch_idx_arr`
    n_crops = tiling[0] * tiling[1]
    crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
    on_crop = 0
    for i in range(tiling[0]):
        # Slide over `src` by `crop_window_size` steps, but extract crops of size
        # `crop_size`, which results in overlapping crop windows
        y0 = i*crop_window_size
        for j in range(tiling[1]):
            x0 = j*crop_window_size
            crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
            patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
            patch_idx += on_crop * crop_patch_h * crop_patch_w

            # Mask out idx that are in the overlap region
            if i != 0:
                patch_idx[:left_margin, :] = -1
            if j != 0:
                patch_idx[:, :left_margin] = -1
            if i != tiling[0]-1:
                patch_idx[-right_margin:, :] = -1
            if j != tiling[1]-1:
                patch_idx[:, -right_margin:] = -1
            patch_idx_arr[on_crop] = patch_idx
            on_crop += 1

    # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
    # so it is ordered left-to-right order
    patch_idx_arr = np.reshape(
        patch_idx_arr,
        [tiling[0], tiling[1], crop_patch_h, crop_patch_w]
    )
    patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
    patch_idx_arr = np.reshape(patch_idx_arr, [-1])

    # Now get the parts not in the overlap region, so it should map each patch in
    # `src` to the correct patch it should come from in `crop_arr`
    patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
        src.shape[0]//image_patch_size,
        src.shape[1]//image_patch_size,
    )
    return crop_arr, patch_idx_arr
218
+
219
+
220
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w(, c)] -> [n_images, n_patches, pixels_per_patch]."""
    n_crops = array.shape[0]
    h_patches = array.shape[1] // patch_size
    w_patches = array.shape[2] // patch_size
    if array.ndim == 3:
        # Channel-less layout [n, h, w]: each patch flattens to patch_size**2 values.
        grid = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
        grid = np.transpose(grid, [0, 1, 3, 2, 4])
        return np.reshape(grid, [n_crops, h_patches * w_patches, patch_size * patch_size])
    # Channelled layout [n, h, w, c]: the channel dim rides along into each patch.
    c = array.shape[3]
    grid = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
    grid = np.transpose(grid, [0, 1, 3, 2, 4, 5])
    return np.reshape(grid, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
238
+
239
+
240
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2D patch-index grid into pooling windows.

    Pads `idx_arr` with -1 (split evenly on both sides) so each dimension is a
    multiple of the pool size, then returns an [h, w, pool_h*pool_w] array
    where each cell lists the patch indices pooled into that output position.
    """
    h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
    w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
    idx_arr = np.pad(
        idx_arr,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode='constant',
        constant_values=-1,
    )
    # Plain-numpy equivalent of einops.rearrange("(h dh) (w dw) -> h w (dh dw)"),
    # dropping the einops dependency for what is a simple reshape/transpose.
    h = idx_arr.shape[0] // pool_h
    w = idx_arr.shape[1] // pool_w
    windows = idx_arr.reshape(h, pool_h, w, pool_w)
    windows = windows.transpose(0, 2, 1, 3)
    return windows.reshape(h, w, pool_h * pool_w)
251
+
252
+
253
def image_to_patches_and_grids(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
    crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    :return image_grids, the shape of each (low-res, high-res) image after pooling
    :return crops, the image crops to processes with the ViT
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
        patches in `crops` to pool for that token, masked with -1
    """
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size

    if crop_mode == "resize":
        # Single global resize only — no high-res crops.
        resized, resize_idx = build_resized_image(
            image,
            base_image_input_size,
            resample,
            image_mean,
            image_std,
            image_patch_size,
        )
        resize_idx = arange_for_pooling(resize_idx, image_pooling_h, image_pooling_w)
        resized_h, resized_w = resize_idx.shape[:2]
        resize_idx = resize_idx.reshape([-1, image_pooling_h * image_pooling_w])
        image_grid = np.stack([np.array([resized_h, resized_w, 0, 0])], 0)
        return (
            image_grid,
            batch_pixels_to_patches(resized, image_patch_size),
            resize_idx,
        )

    if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
        raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")

    crop_arr, patch_idx_arr = build_overlapping_crops(
        image,
        max_crops,
        overlap_margins,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    pooling_idx = arange_for_pooling(patch_idx_arr, image_pooling_h, image_pooling_w)
    h, w = pooling_idx.shape[:2]
    pooling_idx = pooling_idx.reshape([-1, image_pooling_h * image_pooling_w])

    # The global (low-resolution) image gets the same treatment.
    resized, resize_idx = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    crop_arr = np.concatenate([resized, crop_arr], 0)

    resize_idx = arange_for_pooling(resize_idx, image_pooling_h, image_pooling_w)
    resized_h, resized_w = resize_idx.shape[:2]
    resize_idx = resize_idx.reshape([-1, image_pooling_h * image_pooling_w])

    # Global image goes first, so indices into the high-res crops shift up.
    pooling_idx = np.where(
        pooling_idx >= 0,
        pooling_idx + crop_patch_h * crop_patch_w,
        -1,
    )
    pooling_idx = np.concatenate([resize_idx, pooling_idx])
    image_grid = np.stack([np.array([resized_h, resized_w, h, w])], 0)

    return (
        image_grid,
        batch_pixels_to_patches(crop_arr, image_patch_size),
        pooling_idx,
    )
346
+
347
+
348
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
    """Optional image-processing keyword arguments accepted by MolmoAct2."""

    max_crops: Optional[int]
    overlap_margins: Optional[list[int]]
    crop_mode: Optional[str]
    patch_size: Optional[int]
    pooling_size: Optional[list[int]]
354
+
355
+
356
class MolmoAct2ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MolmoAct2 image processor that preprocesses images for the model.

    Args:
        size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
            Size of the image after resizing.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use when resizing the image.
        image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        max_crops (`int`, *optional*, defaults to `8`):
            Maximum number of crops to use per image.
        overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
            Overlap margins to use.
        crop_mode (`str`, *optional*, defaults to `"overlap-and-resize-c2"`):
            Cropping strategy passed through to `image_to_patches_and_grids`.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
            The pooling size of the vision adapter.
    """

    model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]

    def __init__(
        self,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: bool = True,
        max_crops: int = 8,
        overlap_margins: Optional[list[int]] = None,
        crop_mode: str = "overlap-and-resize-c2",
        patch_size: int = 14,
        pooling_size: Optional[list[int]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 378, "width": 378}
        size = get_size_dict(size, default_to_square=True)
        self.size = size

        self.resample = resample
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_convert_rgb = do_convert_rgb

        self.max_crops = max_crops
        # List defaults ([4, 4] / [2, 2]) are materialized here rather than in the
        # signature to avoid the shared-mutable-default-argument pitfall.
        self.overlap_margins = overlap_margins if overlap_margins is not None else [4, 4]
        self.crop_mode = crop_mode
        self.patch_size = patch_size
        self.pooling_size = pooling_size if pooling_size is not None else [2, 2]

    def preprocess(
        self,
        images: ImageInput,
        size: Optional[dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        max_crops: Optional[int] = None,
        overlap_margins: Optional[list[int]] = None,
        crop_mode: Optional[str] = None,
        patch_size: Optional[int] = None,
        pooling_size: Optional[list[int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Args:
            images (`ImageInput`):
                Image to preprocess.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use when resizing the image.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            max_crops (`int`, *optional*, defaults to `self.max_crops`):
                Maximum number of crops to use per image.
            overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
                Overlap margins to use.
            crop_mode (`str`, *optional*, defaults to `self.crop_mode`):
                Cropping strategy.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
                The pooling size of the vision adapter.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return (`'pt'`, `'np'`, `'tf'`, `'jax'`, or unset).

        Returns:
            A `BatchFeature` containing the following keys:
            - `pixel_values`: The preprocessed images.
            - `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
            - `image_grids`: The image grids.
            - `image_num_crops`: The number of crops for each image.
        """
        if size is not None:
            if "height" not in size or "width" not in size:
                raise ValueError("size must contain 'height' and 'width' keys.")
        else:
            size = {**self.size}

        base_image_input_size = [size["height"], size["width"]]

        # Use `is None` checks (not `x or self.x`) so explicit falsy overrides
        # such as do_convert_rgb=False or resample=PILImageResampling.NEAREST
        # (which equals 0) are honoured instead of silently replaced.
        resample = self.resample if resample is None else resample
        image_mean = self.image_mean if image_mean is None else image_mean
        image_std = self.image_std if image_std is None else image_std
        do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb

        max_crops = self.max_crops if max_crops is None else max_crops
        overlap_margins = self.overlap_margins if overlap_margins is None else overlap_margins
        crop_mode = self.crop_mode if crop_mode is None else crop_mode
        patch_size = self.patch_size if patch_size is None else patch_size
        pooling_size = self.pooling_size if pooling_size is None else pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        if images is not None:
            images = self.fetch_images(images)
            images = make_flat_list_of_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        data = {}
        if images is not None:
            batch_grids = []
            batch_crops = []
            batch_pooled_patches_idx = []
            batch_num_crops = []

            for image in images:
                image_grid, crops, pooled_idx = image_to_patches_and_grids(
                    image,
                    max_crops,
                    overlap_margins,
                    base_image_input_size,
                    resample,
                    image_mean,
                    image_std,
                    patch_size,
                    image_pooling_w,
                    image_pooling_h,
                    crop_mode,
                )
                batch_grids.append(image_grid)
                batch_crops.append(crops)
                batch_pooled_patches_idx.append(pooled_idx)
                batch_num_crops.append(crops.shape[0])

            pixel_values = np.concatenate(batch_crops, 0)
            image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
            image_grids = np.concatenate(batch_grids, 0)
            image_num_crops = np.array(batch_num_crops)

            data.update(
                pixel_values=pixel_values,
                image_token_pooling=image_token_pooling,
                image_grids=image_grids,
                image_num_crops=image_num_crops,
            )

        return BatchFeature(data, tensor_type=return_tensors)
544
+
545
+
546
+ MolmoAct2ImageProcessor.register_for_auto_class()
inference.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference utilities for MolmoAct2"""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Iterable, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ from torch.nn import functional as F
8
+ from transformers.cache_utils import Cache
9
+ from transformers.configuration_utils import PretrainedConfig
10
+
11
+
12
@dataclass
class _ActionFlowInputs:
    """Tensor bundle consumed by one action-flow denoising run."""
    trajectory: torch.Tensor  # action trajectory tensor iteratively refined by the flow loop
    context: Any  # conditioning context; exposes kv_contexts, cross_mask, self_mask, valid_action, rope_cache
    modulations: Sequence[Any]  # per-step bundles exposing conditioning, block_modulations, final_modulation
    action_dim_is_pad: Optional[torch.Tensor]  # optional mask marking padded action dimensions


@dataclass
class _ActionFlowCudaGraph:
    """A captured CUDA graph for the action-flow loop and its static I/O buffers."""
    key: Tuple[Any, ...]  # identity key; a mismatch forces re-capture
    graph: torch.cuda.CUDAGraph
    static_inputs: _ActionFlowInputs  # graph-owned buffers new inputs are copied into before replay
    output: torch.Tensor  # graph-owned output buffer (cloned before being handed to callers)


@dataclass
class _DepthDecodeCudaGraphLayerStage:
    """Static per-layer tensors for the graphed depth-decode step.

    NOTE(review): field names suggest residual + Q/K/V produced before
    attention; confirm against the capture code (outside this view).
    """
    residual: torch.Tensor
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphPostStage:
    """A captured per-layer graph plus its attention-context input buffer."""
    graph: torch.cuda.CUDAGraph
    attn_context: torch.Tensor


@dataclass
class _DepthDecodeCudaGraph:
    """Full set of captured graphs and static buffers for single-token depth decoding."""
    cache_key: Tuple[Any, ...]  # identity key; a mismatch forces re-capture
    pre_graph: torch.cuda.CUDAGraph
    token_ids: torch.Tensor  # static input buffer for the next token id
    cos: torch.Tensor  # static buffers — presumably rotary cos/sin; confirm against capture code
    sin: torch.Tensor
    positions: torch.Tensor  # static position-id buffer
    stages: Sequence[_DepthDecodeCudaGraphLayerStage]
    post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
    output: torch.Tensor  # static output buffer


@dataclass
class _DepthDecodeCudaGraphSpec:
    """Once-computed eligibility and identity info for depth-decode graph capture."""
    eligible: bool  # whether the backbone config permits CUDA-graph decoding
    cache_key_prefix: Tuple[Any, ...]  # config-derived part of the graph cache key
    num_hidden_layers: int
    head_dim: int
    num_attention_heads: int
62
+
63
+
64
+ def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
65
+ if past_key_values is None:
66
+ return 0
67
+ seq_len = past_key_values.get_seq_length()
68
+ if torch.is_tensor(seq_len):
69
+ return int(seq_len.item())
70
+ return int(seq_len)
71
+
72
+
73
+ def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
74
+ if past_key_values is None:
75
+ return -1
76
+ max_len = past_key_values.get_max_cache_shape()
77
+ if torch.is_tensor(max_len):
78
+ return int(max_len.item())
79
+ return int(max_len)
80
+
81
+
82
+ def _iter_cache_key_values(
83
+ past_key_values: Cache,
84
+ ) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
85
+ layers = getattr(past_key_values, "layers", None)
86
+ if layers is not None:
87
+ for layer in layers:
88
+ yield getattr(layer, "keys", None), getattr(layer, "values", None)
89
+ return
90
+ for layer in past_key_values:
91
+ yield layer[0], layer[1]
92
+
93
+
94
+ class _DepthDecodeStaticLayerCache:
95
+ is_compileable = False
96
+ is_sliding = False
97
+
98
+ def __init__(self, max_cache_len: int) -> None:
99
+ self.max_cache_len = int(max_cache_len)
100
+ self.cumulative_length = 0
101
+ self.keys: Optional[torch.Tensor] = None
102
+ self.values: Optional[torch.Tensor] = None
103
+
104
+ def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
105
+ bsz, n_heads = key_states.shape[:2]
106
+ self.keys = torch.empty(
107
+ (bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
108
+ dtype=key_states.dtype,
109
+ device=key_states.device,
110
+ )
111
+ self.values = torch.empty(
112
+ (bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
113
+ dtype=value_states.dtype,
114
+ device=value_states.device,
115
+ )
116
+
117
+ def update(
118
+ self,
119
+ key_states: torch.Tensor,
120
+ value_states: torch.Tensor,
121
+ *args,
122
+ **kwargs,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ if self.keys is None:
125
+ self._allocate(key_states, value_states)
126
+ start = self.cumulative_length
127
+ end = start + key_states.shape[-2]
128
+ if end > self.max_cache_len:
129
+ raise RuntimeError(
130
+ f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
131
+ )
132
+ self.keys[:, :, start:end, :].copy_(key_states)
133
+ self.values[:, :, start:end, :].copy_(value_states)
134
+ self.cumulative_length = end
135
+ return self.keys[:, :, :end, :], self.values[:, :, :end, :]
136
+
137
+ def get_seq_length(self) -> int:
138
+ return self.cumulative_length
139
+
140
+ def get_max_cache_shape(self) -> int:
141
+ return -1
142
+
143
+ def reset(self) -> None:
144
+ self.cumulative_length = 0
145
+
146
+
147
class _DepthDecodeStaticCache(Cache):
    """Whole-model static KV cache: one fixed-capacity layer cache per decoder layer.

    Built on `_DepthDecodeStaticLayerCache`, so storage addresses stay stable
    across steps — required for CUDA-graph capture/replay.
    """

    def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
        # NOTE(review): callers pass `config.text_config` here, and this call
        # then applies get_text_config(decoder=True) on top — relies on that
        # being a no-op for a text config; confirm.
        text_config = config.get_text_config(decoder=True)
        super().__init__(
            layers=[
                _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
                for _ in range(text_config.num_hidden_layers)
            ]
        )

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Current filled length of the given layer's cache."""
        return self.layers[layer_idx].get_seq_length()

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """Delegates to the layer cache (which reports -1, i.e. 'unbounded')."""
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self) -> None:
        """Mark every layer cache empty without freeing its buffers."""
        for layer in self.layers:
            layer.reset()
166
+
167
+
168
class ActionCudaGraphManager:
    """Captures and replays the action-expert flow loop as a CUDA graph.

    Re-capture happens whenever the identity key of the inputs changes;
    otherwise new inputs are copied into the captured static buffers and the
    recorded graph is replayed.
    """

    def __init__(self, model: Any) -> None:
        self.model = model
        self.enabled = True
        self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None

    def set_enabled(self, enabled: bool) -> None:
        """Globally enable or disable CUDA-graph capture for the action flow."""
        self.enabled = bool(enabled)

    def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
        """Return True when graph capture/replay is safe for these inputs.

        Requires eval mode on both the model and its action expert, and every
        tensor the captured graph would touch to live on a CUDA device.
        """
        action_model = self.model
        if not self.enabled:
            return False
        if action_model.training or action_model._require_action_expert().training:
            return False
        if inputs.trajectory.device.type != "cuda":
            return False

        def all_on_cuda():
            # Enumerate every tensor the captured graph would read.
            yield inputs.trajectory
            for k, v in inputs.context.kv_contexts:
                yield k
                yield v
            for t in (
                inputs.context.cross_mask,
                inputs.context.self_mask,
                inputs.context.valid_action,
                inputs.action_dim_is_pad,
            ):
                if t is not None:
                    yield t
            if inputs.context.rope_cache is not None:
                yield from inputs.context.rope_cache
            for step in inputs.modulations:
                yield step.conditioning
                for block_modulation in step.block_modulations:
                    yield from block_modulation
                yield from step.final_modulation

        return all(t.device.type == "cuda" for t in all_on_cuda())

    def run_action_flow(
        self,
        inputs: _ActionFlowInputs,
        steps: int,
        run_loop,
    ) -> torch.Tensor:
        """Run `run_loop(inputs, steps)` under a CUDA graph, (re)capturing on key change.

        `_cuda_graph_key`, `_clone_static_inputs`, `_capture_cuda_graph` and
        `_copy_inputs_` are module-level helpers defined elsewhere in this file.
        """
        key = _cuda_graph_key(inputs, steps)
        cache = self.action_flow_graph
        if cache is None or cache.key != key:
            # Capture path: clone inputs into graph-owned static storage and
            # record the loop. after_warmup re-copies the live trajectory into
            # static storage (warmup iterations presumably consume/overwrite it).
            static_inputs = _clone_static_inputs(inputs)
            graph, output = _capture_cuda_graph(
                lambda: run_loop(static_inputs, steps),
                inputs.trajectory.device,
                after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
            )
            cache = _ActionFlowCudaGraph(
                key=key,
                graph=graph,
                static_inputs=static_inputs,
                output=output,
            )
            self.action_flow_graph = cache
        else:
            # Replay path: refresh the static buffers with the new input values.
            _copy_inputs_(cache.static_inputs, inputs)

        cache.graph.replay()
        # Clone so callers don't hold a view into the graph's static output buffer.
        return cache.output.clone()
236
+
237
+
238
+ class DepthDecodeCudaGraphManager:
239
    def __init__(self, model: Any) -> None:
        self.model = model
        self.backbone = model.model  # the language backbone that owns `.transformer`
        self.enabled = True
        # Captured graph and memoized eligibility spec; both lazily populated.
        self.graph: Optional[_DepthDecodeCudaGraph] = None
        self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None

    def set_enabled(self, enabled: bool) -> None:
        """Globally enable or disable CUDA-graph depth decoding."""
        self.enabled = bool(enabled)

    def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
        """Build the fixed-capacity KV cache that the graphed decode path requires.

        NOTE(review): passes `config.text_config` while the cache __init__ then
        calls `get_text_config(decoder=True)` on it — relies on that call being
        a no-op for a text config; confirm.
        """
        return _DepthDecodeStaticCache(
            config=self.model.config.text_config,
            max_cache_len=max_cache_len,
        )
254
+
255
+ def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
256
+ static = self.graph_spec
257
+ if static is None:
258
+ cfg = self.backbone.transformer.config
259
+ rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
260
+ static = _DepthDecodeCudaGraphSpec(
261
+ eligible=(
262
+ not cfg.norm_after
263
+ and cfg.rope_scaling_layers is None
264
+ and getattr(rotary_emb, "rope_type", None) == "default"
265
+ and cfg._attn_implementation == "sdpa"
266
+ ),
267
+ cache_key_prefix=(
268
+ cfg.hidden_size,
269
+ cfg.num_attention_heads,
270
+ cfg.num_key_value_heads,
271
+ cfg.head_dim,
272
+ cfg.num_hidden_layers,
273
+ cfg.use_qk_norm,
274
+ cfg.qk_norm_type,
275
+ cfg._attn_implementation,
276
+ ),
277
+ num_hidden_layers=cfg.num_hidden_layers,
278
+ head_dim=cfg.head_dim,
279
+ num_attention_heads=cfg.num_attention_heads,
280
+ )
281
+ self.graph_spec = static
282
+ return static
283
+
284
+ def can_use(
285
+ self,
286
+ next_input_ids: torch.Tensor,
287
+ *,
288
+ past_key_values: Cache,
289
+ attention_bias: torch.Tensor,
290
+ ) -> bool:
291
+ if (
292
+ not self.enabled
293
+ or self.model.training
294
+ or self.backbone.transformer.training
295
+ ):
296
+ return False
297
+ if next_input_ids.device.type != "cuda":
298
+ return False
299
+ if (
300
+ next_input_ids.ndim != 2
301
+ or next_input_ids.shape[0] != 1
302
+ or next_input_ids.shape[1] != 1
303
+ ):
304
+ return False
305
+ if not isinstance(past_key_values, _DepthDecodeStaticCache):
306
+ return False
307
+ if (
308
+ not torch.is_tensor(attention_bias)
309
+ or attention_bias.device != next_input_ids.device
310
+ ):
311
+ return False
312
+ return self._depth_decode_spec().eligible
313
+
314
+ def _depth_decode_key(
315
+ self,
316
+ next_input_ids: torch.Tensor,
317
+ attention_bias: torch.Tensor,
318
+ ) -> Tuple[Any, ...]:
319
+ device = next_input_ids.device
320
+ return (
321
+ self._depth_decode_spec().cache_key_prefix,
322
+ device.type,
323
+ device.index,
324
+ self.model.lm_head.weight.dtype,
325
+ attention_bias.shape[-1],
326
+ )
327
+
328
+ def _select_depth_decode_rope(
329
+ self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
330
+ ) -> None:
331
+ emb = self.backbone.transformer.rotary_emb
332
+ cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
333
+ sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
334
+
335
+ def _depth_decode_pre_layer(
336
+ self,
337
+ layer_idx: int,
338
+ hidden_states: torch.Tensor,
339
+ cos: torch.Tensor,
340
+ sin: torch.Tensor,
341
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
342
+ block = self.backbone.transformer.blocks[layer_idx]
343
+ attention = block.self_attn
344
+ residual = hidden_states
345
+ hidden_states = block.attn_norm(hidden_states)
346
+
347
+ input_shape = hidden_states.shape[:-1]
348
+ hidden_shape = (*input_shape, -1, attention.head_dim)
349
+ qkv = attention.att_proj(hidden_states)
350
+ query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
351
+ value_states = value_states.view(hidden_shape)
352
+
353
+ apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
354
+ norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
355
+
356
+ if apply_qk_norm and not norm_after_view:
357
+ query_states = attention.q_norm(query_states)
358
+ key_states = attention.k_norm(key_states)
359
+
360
+ query_states = query_states.view(hidden_shape)
361
+ key_states = key_states.view(hidden_shape)
362
+
363
+ if norm_after_view:
364
+ query_states = attention.q_norm(query_states)
365
+ key_states = attention.k_norm(key_states)
366
+
367
+ query_states = query_states.transpose(1, 2)
368
+ key_states = key_states.transpose(1, 2)
369
+ value_states = value_states.transpose(1, 2)
370
+ query_states, key_states = _apply_rotary_pos_emb(
371
+ query_states, key_states, cos, sin
372
+ )
373
+ return residual, query_states, key_states, value_states
374
+
375
+ def _depth_decode_pre0(
376
+ self,
377
+ token_ids: torch.Tensor,
378
+ cos: torch.Tensor,
379
+ sin: torch.Tensor,
380
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
381
+ inputs_embeds = self.model._embed_base_tokens(token_ids)
382
+ return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
383
+
384
+ def _depth_decode_post_layer(
385
+ self,
386
+ layer_idx: int,
387
+ residual: torch.Tensor,
388
+ attn_context: torch.Tensor,
389
+ ) -> torch.Tensor:
390
+ block = self.backbone.transformer.blocks[layer_idx]
391
+ attention = block.self_attn
392
+ input_shape = residual.shape[:-1]
393
+ attn_output = attn_context.reshape(*input_shape, -1).contiguous()
394
+ attn_output = attention.attn_out(attn_output)
395
+ hidden_states = residual + block.dropout(attn_output)
396
+
397
+ residual = hidden_states
398
+ hidden_states = block.ff_norm(hidden_states)
399
+ hidden_states = block.mlp(hidden_states)
400
+ hidden_states = residual + block.dropout(hidden_states)
401
+ return hidden_states
402
+
403
+ def _depth_decode_post_and_pre_next(
404
+ self,
405
+ layer_idx: int,
406
+ residual: torch.Tensor,
407
+ attn_context: torch.Tensor,
408
+ cos: torch.Tensor,
409
+ sin: torch.Tensor,
410
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
411
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
412
+ return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
413
+
414
+ def _depth_decode_last_post(
415
+ self,
416
+ layer_idx: int,
417
+ residual: torch.Tensor,
418
+ attn_context: torch.Tensor,
419
+ ) -> torch.Tensor:
420
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
421
+ return self.backbone.transformer.ln_f(hidden_states)
422
+
423
+ def _build_depth_decode_graph(
424
+ self,
425
+ next_input_ids: torch.Tensor,
426
+ *,
427
+ past_length: int,
428
+ attention_bias: torch.Tensor,
429
+ ) -> _DepthDecodeCudaGraph:
430
+ text_config = self.backbone.transformer.config
431
+ device = next_input_ids.device
432
+ dtype = self.model.lm_head.weight.dtype
433
+ static = self._depth_decode_spec()
434
+ num_layers = static.num_hidden_layers
435
+ head_dim = static.head_dim
436
+ max_cache_len = int(attention_bias.shape[-1])
437
+ max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
438
+ self.backbone.transformer.prepare_rope_cache(
439
+ device=device, max_seq_len=max_rope_len
440
+ )
441
+
442
+ token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
443
+ cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
444
+ sin = torch.empty_like(cos)
445
+ positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
446
+ context_shape = (1, 1, static.num_attention_heads, head_dim)
447
+
448
+ token_ids.copy_(next_input_ids)
449
+ self._select_depth_decode_rope(cos, sin, past_length=past_length)
450
+
451
+ pre_graph, pre_output = _capture_cuda_graph(
452
+ lambda: self._depth_decode_pre0(token_ids, cos, sin),
453
+ device,
454
+ )
455
+ stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
456
+ post_graphs = []
457
+ for layer_idx in range(num_layers - 1):
458
+ stage = stages[-1]
459
+ attn_context = torch.empty(context_shape, device=device, dtype=dtype)
460
+ graph, output = _capture_cuda_graph(
461
+ lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
462
+ self._depth_decode_post_and_pre_next(
463
+ layer_idx,
464
+ stage.residual,
465
+ attn_context,
466
+ cos,
467
+ sin,
468
+ )
469
+ ),
470
+ device,
471
+ )
472
+ post_graphs.append(
473
+ _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
474
+ )
475
+ stages.append(_DepthDecodeCudaGraphLayerStage(*output))
476
+
477
+ last_stage = stages[-1]
478
+ last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
479
+ last_graph, last_output = _capture_cuda_graph(
480
+ lambda: self._depth_decode_last_post(
481
+ num_layers - 1,
482
+ last_stage.residual,
483
+ last_attn_context,
484
+ ),
485
+ device,
486
+ )
487
+ post_graphs.append(
488
+ _DepthDecodeCudaGraphPostStage(
489
+ graph=last_graph, attn_context=last_attn_context
490
+ )
491
+ )
492
+ return _DepthDecodeCudaGraph(
493
+ cache_key=self._depth_decode_key(next_input_ids, attention_bias),
494
+ pre_graph=pre_graph,
495
+ token_ids=token_ids,
496
+ cos=cos,
497
+ sin=sin,
498
+ positions=positions,
499
+ stages=tuple(stages),
500
+ post_graphs=tuple(post_graphs),
501
+ output=last_output,
502
+ )
503
+
504
+ def _get_depth_decode_graph(
505
+ self,
506
+ next_input_ids: torch.Tensor,
507
+ *,
508
+ past_length: int,
509
+ attention_bias: torch.Tensor,
510
+ ) -> _DepthDecodeCudaGraph:
511
+ key = self._depth_decode_key(next_input_ids, attention_bias)
512
+ decode_graph = self.graph
513
+ if decode_graph is None or decode_graph.cache_key != key:
514
+ decode_graph = self._build_depth_decode_graph(
515
+ next_input_ids,
516
+ past_length=past_length,
517
+ attention_bias=attention_bias,
518
+ )
519
+ self.graph = decode_graph
520
+ else:
521
+ decode_graph.token_ids.copy_(next_input_ids)
522
+ self._select_depth_decode_rope(
523
+ decode_graph.cos, decode_graph.sin, past_length=past_length
524
+ )
525
+ return decode_graph
526
+
527
+ def _run_depth_decode_attention_core(
528
+ self,
529
+ layer_idx: int,
530
+ stage: _DepthDecodeCudaGraphLayerStage,
531
+ *,
532
+ past_key_values: Cache,
533
+ attention_bias: torch.Tensor,
534
+ cache_position: torch.Tensor,
535
+ cos: torch.Tensor,
536
+ sin: torch.Tensor,
537
+ ) -> torch.Tensor:
538
+ attention = self.backbone.transformer.blocks[layer_idx].self_attn
539
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
540
+ key_states, value_states = past_key_values.update(
541
+ stage.key,
542
+ stage.value,
543
+ layer_idx,
544
+ cache_kwargs,
545
+ )
546
+ key_states = _repeat_kv(key_states, attention.num_key_value_groups)
547
+ value_states = _repeat_kv(value_states, attention.num_key_value_groups)
548
+ attn_output = F.scaled_dot_product_attention(
549
+ stage.query,
550
+ key_states,
551
+ value_states,
552
+ attn_mask=attention_bias,
553
+ dropout_p=0.0,
554
+ is_causal=False,
555
+ )
556
+ return attn_output.transpose(1, 2)
557
+
558
+ def run(
559
+ self,
560
+ next_input_ids: torch.Tensor,
561
+ *,
562
+ past_key_values: Cache,
563
+ attention_bias: torch.Tensor,
564
+ past_length: int,
565
+ ) -> Tuple[torch.Tensor, Cache]:
566
+ end = past_length + 1
567
+ decode_graph = self._get_depth_decode_graph(
568
+ next_input_ids,
569
+ past_length=past_length,
570
+ attention_bias=attention_bias,
571
+ )
572
+ cache_position = decode_graph.positions[past_length:end]
573
+ attention_bias_q = attention_bias[:, :, past_length:end, :end]
574
+
575
+ decode_graph.pre_graph.replay()
576
+
577
+ for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
578
+ attn_context = self._run_depth_decode_attention_core(
579
+ layer_idx,
580
+ decode_graph.stages[layer_idx],
581
+ past_key_values=past_key_values,
582
+ attention_bias=attention_bias_q,
583
+ cache_position=cache_position,
584
+ cos=decode_graph.cos,
585
+ sin=decode_graph.sin,
586
+ )
587
+ post_graph.attn_context.copy_(attn_context)
588
+ post_graph.graph.replay()
589
+
590
+ return decode_graph.output, past_key_values
591
+
592
+
593
+ def _cuda_graph_tensor_signature(
594
+ tensor: Optional[torch.Tensor],
595
+ ) -> Optional[Tuple[Any, ...]]:
596
+ if tensor is None:
597
+ return None
598
+ return (
599
+ tuple(tensor.shape),
600
+ tuple(tensor.stride()),
601
+ str(tensor.dtype),
602
+ str(tensor.device),
603
+ )
604
+
605
+
606
+ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
607
+ sig = _cuda_graph_tensor_signature
608
+ return (
609
+ tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
610
+ sig(context.cross_mask),
611
+ sig(context.self_mask),
612
+ sig(context.valid_action),
613
+ None
614
+ if context.rope_cache is None
615
+ else tuple(sig(t) for t in context.rope_cache),
616
+ )
617
+
618
+
619
+ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
620
+ sig = _cuda_graph_tensor_signature
621
+ return tuple(
622
+ (
623
+ sig(step.conditioning),
624
+ tuple(
625
+ tuple(sig(t) for t in block_modulation)
626
+ for block_modulation in step.block_modulations
627
+ ),
628
+ tuple(sig(t) for t in step.final_modulation),
629
+ )
630
+ for step in modulations
631
+ )
632
+
633
+
634
+ def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
635
+ sig = _cuda_graph_tensor_signature
636
+ return (
637
+ sig(inputs.trajectory),
638
+ _cuda_graph_context_signature(inputs.context),
639
+ _cuda_graph_modulation_signature(inputs.modulations),
640
+ sig(inputs.action_dim_is_pad),
641
+ int(steps),
642
+ )
643
+
644
+
645
+ def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
646
+ if tensor is None:
647
+ return None
648
+ static = torch.empty_strided(
649
+ tuple(tensor.shape),
650
+ tuple(tensor.stride()),
651
+ device=tensor.device,
652
+ dtype=tensor.dtype,
653
+ )
654
+ static.copy_(tensor)
655
+ return static
656
+
657
+
658
+ def _clone_static_context(context: Any) -> Any:
659
+ rope_cache = None
660
+ if context.rope_cache is not None:
661
+ rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
662
+ return context.__class__(
663
+ kv_contexts=tuple(
664
+ (_clone_static_tensor(k), _clone_static_tensor(v))
665
+ for k, v in context.kv_contexts
666
+ ),
667
+ cross_mask=_clone_static_tensor(context.cross_mask),
668
+ self_mask=_clone_static_tensor(context.self_mask),
669
+ valid_action=_clone_static_tensor(context.valid_action),
670
+ rope_cache=rope_cache,
671
+ )
672
+
673
+
674
+ def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
675
+ return tuple(
676
+ step.__class__(
677
+ conditioning=_clone_static_tensor(step.conditioning),
678
+ block_modulations=tuple(
679
+ tuple(_clone_static_tensor(t) for t in block_modulation)
680
+ for block_modulation in step.block_modulations
681
+ ),
682
+ final_modulation=tuple(
683
+ _clone_static_tensor(t) for t in step.final_modulation
684
+ ),
685
+ )
686
+ for step in modulations
687
+ )
688
+
689
+
690
+ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
691
+ return _ActionFlowInputs(
692
+ trajectory=_clone_static_tensor(inputs.trajectory),
693
+ context=_clone_static_context(inputs.context),
694
+ modulations=_clone_static_modulations(inputs.modulations),
695
+ action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
696
+ )
697
+
698
+
699
+ def _copy_context_(dst: Any, src: Any) -> None:
700
+ for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
701
+ dst_k.copy_(src_k)
702
+ dst_v.copy_(src_v)
703
+ if src.cross_mask is not None:
704
+ dst.cross_mask.copy_(src.cross_mask)
705
+ if src.self_mask is not None:
706
+ dst.self_mask.copy_(src.self_mask)
707
+ if src.valid_action is not None:
708
+ dst.valid_action.copy_(src.valid_action)
709
+ if src.rope_cache is not None:
710
+ for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
711
+ dst_tensor.copy_(src_tensor)
712
+
713
+
714
+ def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
715
+ dst.trajectory.copy_(src.trajectory)
716
+ _copy_context_(dst.context, src.context)
717
+ if src.action_dim_is_pad is not None:
718
+ dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
719
+
720
+
721
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
722
+ x1 = x[..., : x.shape[-1] // 2]
723
+ x2 = x[..., x.shape[-1] // 2 :]
724
+ return torch.cat((-x2, x1), dim=-1)
725
+
726
+
727
+ def _apply_rotary_pos_emb(
728
+ q: torch.Tensor,
729
+ k: torch.Tensor,
730
+ cos: torch.Tensor,
731
+ sin: torch.Tensor,
732
+ unsqueeze_dim: int = 1,
733
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
734
+ cos = cos.unsqueeze(unsqueeze_dim)
735
+ sin = sin.unsqueeze(unsqueeze_dim)
736
+ q_embed = (q * cos) + (_rotate_half(q) * sin)
737
+ k_embed = (k * cos) + (_rotate_half(k) * sin)
738
+ return q_embed, k_embed
739
+
740
+
741
+ def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
742
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
743
+ if n_rep == 1:
744
+ return hidden_states
745
+ hidden_states = hidden_states[:, :, None, :, :].expand(
746
+ batch, num_key_value_heads, n_rep, slen, head_dim
747
+ )
748
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
749
+
750
+
751
+ def _capture_cuda_graph(
752
+ fn,
753
+ device: torch.device,
754
+ *,
755
+ after_warmup=None,
756
+ ) -> Tuple[torch.cuda.CUDAGraph, Any]:
757
+ warmup_stream = torch.cuda.Stream(device=device)
758
+ warmup_stream.wait_stream(torch.cuda.current_stream(device))
759
+ with torch.cuda.stream(warmup_stream):
760
+ fn()
761
+ torch.cuda.current_stream(device).wait_stream(warmup_stream)
762
+ if after_warmup is not None:
763
+ after_warmup()
764
+
765
+ graph = torch.cuda.CUDAGraph()
766
+ with torch.cuda.graph(graph):
767
+ output = fn()
768
+ return graph, output
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08663a88bc5c1f6c1cf8534a8cdf7971eb2fd66979ac42d38752d5209b971e6b
3
+ size 4929809880
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9282d1390ca3fe81f31eaa4a925bc881ecb7c56c315eb8d3d1e2f6616cba7af9
3
+ size 4844690992
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f35fdd96126d04975d1feae1b715a875d03ad9d07c7de25fa91a6573cceb1e7
3
+ size 4844691024
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1b546c537a6eb6b5e281e5b8d6819ffe8ee8f88c55da5025e1e6ee2721cc907
3
+ size 4998106920
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a925159b5e363f72007c38a2c4c2fc7cdac6e8b6ae990ddaa51f3abe526beb77
3
+ size 2345090936
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_molmoact2.py ADDED
The diff for this file is too large to render. See raw diff
 
norm_stats.json ADDED
@@ -0,0 +1,1739 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "format": "molmoact2_norm_stats.v1",
3
+ "norm_mode": "q01_q99",
4
+ "metadata_by_tag": {
5
+ "franka_molmoact": {
6
+ "action_key": "action.del_ee_action",
7
+ "state_key": "observation.state",
8
+ "camera_keys": [
9
+ "observation.images.primary",
10
+ "observation.images.secondary"
11
+ ],
12
+ "normalize_gripper": false,
13
+ "action_horizon": 10,
14
+ "n_action_steps": 10,
15
+ "setup_type": "single franka robotic arm in molmoact",
16
+ "control_mode": "delta end-effector pose",
17
+ "action_stats": {
18
+ "min": [
19
+ -0.07434078305959702,
20
+ -0.07339745759963989,
21
+ -0.06539416313171387,
22
+ -0.1688285619020462,
23
+ -0.10289879888296127,
24
+ -0.2667275667190552,
25
+ 0.0
26
+ ],
27
+ "max": [
28
+ 0.06042003631591797,
29
+ 0.09417290985584259,
30
+ 0.07019275426864624,
31
+ 0.2616892158985138,
32
+ 0.11751057207584381,
33
+ 0.16968433558940887,
34
+ 1.0
35
+ ],
36
+ "mean": [
37
+ 0.0005923698136522352,
38
+ 0.000245022598131832,
39
+ -4.604843771714063e-05,
40
+ 0.00022562421486693225,
41
+ -0.0005166618849942836,
42
+ -0.0002193919428051152,
43
+ 0.557619424517478
44
+ ],
45
+ "std": [
46
+ 0.005274540883280089,
47
+ 0.007662320435387572,
48
+ 0.006516662891595147,
49
+ 0.013564563259375743,
50
+ 0.011179215063905077,
51
+ 0.015195633113705318,
52
+ 0.49666890583432166
53
+ ],
54
+ "count": [
55
+ 1482599.0
56
+ ],
57
+ "q01": [
58
+ -0.011251236153566059,
59
+ -0.014918113203115847,
60
+ -0.011753186696798671,
61
+ -0.02785908205770074,
62
+ -0.025679744407356857,
63
+ -0.03279371599275369,
64
+ 3.7096921558780464e-05
65
+ ],
66
+ "q10": [
67
+ -0.005157558671709432,
68
+ -0.007627389508279324,
69
+ -0.006774633516067545,
70
+ -0.013867640389035468,
71
+ -0.01314412247667587,
72
+ -0.016390209597024155,
73
+ 0.012615970474925397
74
+ ],
75
+ "q50": [
76
+ 0.00047587496704591567,
77
+ -5.756867525949417e-05,
78
+ -0.0004126053693703461,
79
+ 0.00010505655624582394,
80
+ 6.41251115100509e-05,
81
+ -0.00028445035571581385,
82
+ 0.6884608295876035
83
+ ],
84
+ "q90": [
85
+ 0.006329953764757322,
86
+ 0.008685542226677301,
87
+ 0.008054135204293992,
88
+ 0.01397800046720906,
89
+ 0.010417940135392682,
90
+ 0.016052135597642597,
91
+ 0.9523435823006378
92
+ ],
93
+ "q99": [
94
+ 0.01117553552099105,
95
+ 0.016859041899882184,
96
+ 0.015590574732817865,
97
+ 0.029286192888436802,
98
+ 0.023178728984454205,
99
+ 0.031348125431223534,
100
+ 0.9523741512556912
101
+ ],
102
+ "names": [
103
+ "x",
104
+ "y",
105
+ "z",
106
+ "rx",
107
+ "ry",
108
+ "rz",
109
+ "gripper"
110
+ ],
111
+ "mask": [
112
+ true,
113
+ true,
114
+ true,
115
+ true,
116
+ true,
117
+ true,
118
+ false
119
+ ]
120
+ },
121
+ "state_stats": {
122
+ "min": [
123
+ -0.26428329944610596,
124
+ -0.6690786480903625,
125
+ -0.11737073212862015,
126
+ -3.141592264175415,
127
+ -1.4651211500167847,
128
+ -2.9524343013763428,
129
+ -0.014517154544591904
130
+ ],
131
+ "max": [
132
+ 0.8226616978645325,
133
+ 0.7252005338668823,
134
+ 0.9137527346611023,
135
+ 3.141592264175415,
136
+ 1.3202887773513794,
137
+ 1.35053551197052,
138
+ 1.0004678964614868
139
+ ],
140
+ "mean": [
141
+ 0.524974531443455,
142
+ -0.009077995640631331,
143
+ 0.37626277677807307,
144
+ -1.1230985182050761,
145
+ -0.15037831955429493,
146
+ -0.8360877101239638,
147
+ 0.4828000792054066
148
+ ],
149
+ "std": [
150
+ 0.108734747557998,
151
+ 0.19003219833018514,
152
+ 0.13128520583933115,
153
+ 2.684752550519181,
154
+ 0.2757065272207643,
155
+ 0.41552838131031417,
156
+ 0.44164051887464084
157
+ ],
158
+ "count": [
159
+ 1482599.0
160
+ ],
161
+ "q01": [
162
+ 0.3835934249761093,
163
+ -0.16975945635008685,
164
+ 0.26948875059068883,
165
+ -3.1154851500608602,
166
+ -0.3736599588300681,
167
+ -1.1556019879922976,
168
+ -0.010862173339892917
169
+ ],
170
+ "q10": [
171
+ 0.4262234785864046,
172
+ -0.14936716135148972,
173
+ 0.28255691882796946,
174
+ -2.886297848869603,
175
+ -0.33795875325689667,
176
+ -1.111211596662902,
177
+ -0.010235056183139664
178
+ ],
179
+ "q50": [
180
+ 0.526748316869364,
181
+ -0.018291417601920976,
182
+ 0.37479483390901625,
183
+ -1.3501052773711595,
184
+ -0.15482441515331163,
185
+ -0.8602460018117026,
186
+ 0.5755928858365665
187
+ ],
188
+ "q90": [
189
+ 0.6151325476883934,
190
+ 0.13938787377123313,
191
+ 0.47141213932315906,
192
+ 0.9833625012077128,
193
+ 0.044654971996918924,
194
+ -0.5307531964313489,
195
+ 0.8665980232471624
196
+ ],
197
+ "q99": [
198
+ 0.6251391330078412,
199
+ 0.16582465215431033,
200
+ 0.5049115299029577,
201
+ 1.351274663819693,
202
+ 0.09734563994616442,
203
+ -0.4656737437923123,
204
+ 0.8676457869434169
205
+ ],
206
+ "names": [
207
+ "x",
208
+ "y",
209
+ "z",
210
+ "rx",
211
+ "ry",
212
+ "rz",
213
+ "gripper"
214
+ ],
215
+ "mask": [
216
+ true,
217
+ true,
218
+ true,
219
+ true,
220
+ true,
221
+ true,
222
+ false
223
+ ]
224
+ }
225
+ },
226
+ "franka_droid": {
227
+ "action_key": "action",
228
+ "state_key": "observation.state",
229
+ "camera_keys": [
230
+ "observation.images.exterior_1_left",
231
+ "observation.images.exterior_2_left",
232
+ "observation.images.wrist_left"
233
+ ],
234
+ "normalize_gripper": false,
235
+ "action_horizon": 15,
236
+ "n_action_steps": 15,
237
+ "setup_type": "single franka robotic arm in droid",
238
+ "control_mode": "absolute joint pose",
239
+ "action_stats": {
240
+ "min": [
241
+ -2.781099557876587,
242
+ -1.6407934427261353,
243
+ -2.7493984699249268,
244
+ -2.9508564472198486,
245
+ -2.7826988697052,
246
+ 0.17983438074588776,
247
+ -2.901715040206909,
248
+ 0.0
249
+ ],
250
+ "max": [
251
+ 2.7449073791503906,
252
+ 1.6668277978897095,
253
+ 2.7546653747558594,
254
+ -0.1936211884021759,
255
+ 2.7786083221435547,
256
+ 4.402013778686523,
257
+ 2.90183162689209,
258
+ 1.0
259
+ ],
260
+ "mean": [
261
+ 0.010418229819396566,
262
+ 0.28233935319840636,
263
+ -0.015346633420959944,
264
+ -2.0060878874674715,
265
+ -0.029448930257783886,
266
+ 2.350942437431684,
267
+ 0.09820869537671756,
268
+ 0.4390250813949694
269
+ ],
270
+ "std": [
271
+ 0.3170372143097277,
272
+ 0.4863630998896905,
273
+ 0.27477375809610444,
274
+ 0.48806966037647037,
275
+ 0.528105567983804,
276
+ 0.4517944470893175,
277
+ 0.7430287051319469,
278
+ 0.44171628153080567
279
+ ],
280
+ "count": [
281
+ 17758044.0
282
+ ],
283
+ "q01": [
284
+ -0.2879620949867506,
285
+ -0.5702219304684566,
286
+ -0.31101638810433413,
287
+ -2.5622234922052725,
288
+ -0.5101021838814974,
289
+ 1.7376836093987995,
290
+ -0.5227783063045004,
291
+ 8.3274762776141e-05
292
+ ],
293
+ "q10": [
294
+ -0.2014563967491066,
295
+ -0.1953558605308627,
296
+ -0.20948523622127932,
297
+ -2.402722104277799,
298
+ -0.3766226599271436,
299
+ 1.9723782378158212,
300
+ -0.35517365133256956,
301
+ 0.006328163222238114
302
+ ],
303
+ "q50": [
304
+ 0.007162417319678469,
305
+ 0.336456475452052,
306
+ -0.013974891825914252,
307
+ -2.008015245005848,
308
+ -0.025656272672692895,
309
+ 2.3675065323600304,
310
+ 0.09545267517159627,
311
+ 0.4136583280016901
312
+ ],
313
+ "q90": [
314
+ 0.226670788411882,
315
+ 0.6598602025239771,
316
+ 0.17603345191458397,
317
+ -1.6119685483207011,
318
+ 0.3092560943750579,
319
+ 2.7008573894589345,
320
+ 0.5521874183259908,
321
+ 0.8737818408325101
322
+ ],
323
+ "q99": [
324
+ 0.32190872731317344,
325
+ 0.7405054873177153,
326
+ 0.2737893247287367,
327
+ -1.5075067942029405,
328
+ 0.4329542718063284,
329
+ 2.804162424656418,
330
+ 0.7128911284154664,
331
+ 0.8917437724235555
332
+ ],
333
+ "names": [
334
+ "joint_0",
335
+ "joint_1",
336
+ "joint_2",
337
+ "joint_3",
338
+ "joint_4",
339
+ "joint_5",
340
+ "joint_6",
341
+ "gripper"
342
+ ],
343
+ "mask": [
344
+ true,
345
+ true,
346
+ true,
347
+ true,
348
+ true,
349
+ true,
350
+ true,
351
+ false
352
+ ]
353
+ },
354
+ "state_stats": {
355
+ "min": [
356
+ -2.6536705493927,
357
+ -1.6156227588653564,
358
+ -2.6781487464904785,
359
+ -2.9409868717193604,
360
+ -2.6705946922302246,
361
+ 0.24893812835216522,
362
+ -2.757359266281128,
363
+ 0.0
364
+ ],
365
+ "max": [
366
+ 2.6687583923339844,
367
+ 1.5840554237365723,
368
+ 2.666306734085083,
369
+ -0.29779934883117676,
370
+ 2.6624162197113037,
371
+ 4.272191524505615,
372
+ 2.755643367767334,
373
+ 1.0
374
+ ],
375
+ "mean": [
376
+ 0.011081824850861873,
377
+ 0.27280296447760194,
378
+ -0.01550719225628586,
379
+ -2.01647228106023,
380
+ -0.029620826332964655,
381
+ 2.3483866081585507,
382
+ 0.09636965416886735,
383
+ 0.3927326432557614
384
+ ],
385
+ "std": [
386
+ 0.31291266868924655,
387
+ 0.4934370267472678,
388
+ 0.2728791258795487,
389
+ 0.48437020229024425,
390
+ 0.521435680610052,
391
+ 0.44821751701382595,
392
+ 0.7352730961005634,
393
+ 0.4070640216658998
394
+ ],
395
+ "count": [
396
+ 17758044.0
397
+ ],
398
+ "q01": [
399
+ -0.2793009809782748,
400
+ -0.5873924424866738,
401
+ -0.3058546817065916,
402
+ -2.5639055042030354,
403
+ -0.491431808753978,
404
+ 1.7381500993283228,
405
+ -0.5086147192989775,
406
+ 1.6414552399718753e-05
407
+ ],
408
+ "q10": [
409
+ -0.1994457930505723,
410
+ -0.2381088441987148,
411
+ -0.2103897594636481,
412
+ -2.421918892949847,
413
+ -0.3725951094142233,
414
+ 1.961410109454104,
415
+ -0.35782982482940473,
416
+ 0.005005809072924616
417
+ ],
418
+ "q50": [
419
+ 0.007891181486763803,
420
+ 0.3376595448103942,
421
+ -0.014280627673021464,
422
+ -2.0134951539128574,
423
+ -0.025990006808582142,
424
+ 2.3690656185268972,
425
+ 0.09443906823538496,
426
+ 0.38343357074070045
427
+ ],
428
+ "q90": [
429
+ 0.22605189533019984,
430
+ 0.6543162155730768,
431
+ 0.17689204963635444,
432
+ -1.6243810394635305,
433
+ 0.30497772553178637,
434
+ 2.696376125344824,
435
+ 0.5494813775877777,
436
+ 0.7734412581580631
437
+ ],
438
+ "q99": [
439
+ 0.3148177895778054,
440
+ 0.7235689468221655,
441
+ 0.2683897323238184,
442
+ -1.530780071911146,
443
+ 0.415067150345451,
444
+ 2.7863710743039887,
445
+ 0.6952765173061115,
446
+ 0.7968550629755542
447
+ ],
448
+ "names": [
449
+ "joint_0",
450
+ "joint_1",
451
+ "joint_2",
452
+ "joint_3",
453
+ "joint_4",
454
+ "joint_5",
455
+ "joint_6",
456
+ "gripper"
457
+ ],
458
+ "mask": [
459
+ true,
460
+ true,
461
+ true,
462
+ true,
463
+ true,
464
+ true,
465
+ true,
466
+ false
467
+ ]
468
+ }
469
+ },
470
+ "google_robot_fractal": {
471
+ "action_key": "action",
472
+ "state_key": "observation.state",
473
+ "camera_keys": [
474
+ "observation.images.image"
475
+ ],
476
+ "normalize_gripper": false,
477
+ "action_horizon": 3,
478
+ "n_action_steps": 3,
479
+ "setup_type": "google robot in rt_1",
480
+ "control_mode": "delta end-effector pose",
481
+ "action_stats": {
482
+ "min": [
483
+ -2.0204520225524902,
484
+ -5.497899532318115,
485
+ -2.031663417816162,
486
+ -1.569917917251587,
487
+ -1.569892168045044,
488
+ -1.570419430732727,
489
+ 0.0
490
+ ],
491
+ "max": [
492
+ 2.9984593391418457,
493
+ 22.09052848815918,
494
+ 2.7507524490356445,
495
+ 1.570636510848999,
496
+ 1.5321086645126343,
497
+ 1.5691522359848022,
498
+ 1.0
499
+ ],
500
+ "mean": [
501
+ 0.006986742172085001,
502
+ 0.006266400645656189,
503
+ -0.012625619452946994,
504
+ 0.04333477176605177,
505
+ -0.005755843126369106,
506
+ 0.0009133710921551742,
507
+ 0.5354204546016331
508
+ ],
509
+ "std": [
510
+ 0.06943342828666754,
511
+ 0.05987580207886052,
512
+ 0.07384291122356837,
513
+ 0.15697640227077467,
514
+ 0.13192376844373777,
515
+ 0.1463219229157086,
516
+ 0.49874381100185294
517
+ ],
518
+ "count": [
519
+ 3786400.0
520
+ ],
521
+ "q01": [
522
+ -0.22488493870935375,
523
+ -0.14842987771463928,
524
+ -0.23165991540148315,
525
+ -0.3518507387123856,
526
+ -0.4191961375830685,
527
+ -0.43642424734739155,
528
+ -1.000000013351432e-10
529
+ ],
530
+ "q10": [
531
+ -0.057097137110108394,
532
+ -0.04180085777840345,
533
+ -0.08797302699742898,
534
+ -0.08695764133325046,
535
+ -0.14987822626697328,
536
+ -0.14407043696379337,
537
+ -1.000000013351432e-10
538
+ ],
539
+ "q50": [
540
+ 0.0024323156617234785,
541
+ 0.001999621430072272,
542
+ -0.006186507557852898,
543
+ 0.010844173385829027,
544
+ 9.716094932283909e-05,
545
+ 0.00029282634304717123,
546
+ 0.9998131999001298
547
+ ],
548
+ "q90": [
549
+ 0.0799327921066265,
550
+ 0.06281248479995295,
551
+ 0.05719906967641521,
552
+ 0.2181351081319081,
553
+ 0.12581539646577725,
554
+ 0.14653933152766907,
555
+ 0.999962639980026
556
+ ],
557
+ "q99": [
558
+ 0.1780379284730618,
559
+ 0.1492598341805028,
560
+ 0.2184954847280796,
561
+ 0.5894017219543457,
562
+ 0.3527610110385077,
563
+ 0.4478335709948289,
564
+ 0.9999962639980026
565
+ ],
566
+ "names": [
567
+ "x",
568
+ "y",
569
+ "z",
570
+ "roll",
571
+ "pitch",
572
+ "yaw",
573
+ "gripper"
574
+ ],
575
+ "mask": [
576
+ true,
577
+ true,
578
+ true,
579
+ true,
580
+ true,
581
+ true,
582
+ false
583
+ ]
584
+ },
585
+ "state_stats": {
586
+ "min": [
587
+ -0.4436439275741577,
588
+ -0.9970501065254211,
589
+ -0.006579156965017319,
590
+ -0.8643477559089661,
591
+ -0.7079970240592957,
592
+ -0.7688722014427185,
593
+ -0.4999994933605194,
594
+ 0.0
595
+ ],
596
+ "max": [
597
+ 1.0534898042678833,
598
+ 0.48018959164619446,
599
+ 1.6896663904190063,
600
+ 0.9999993443489075,
601
+ 0.9999874830245972,
602
+ 0.9554369449615479,
603
+ 0.9914546012878418,
604
+ 1.0
605
+ ],
606
+ "mean": [
607
+ 0.5582046028643476,
608
+ -0.08324323429555826,
609
+ 0.7708198142579598,
610
+ -0.24752762586024715,
611
+ 0.4959921774813562,
612
+ 0.0925577145133276,
613
+ 0.20941890216560163,
614
+ 0.42619563761216767
615
+ ],
616
+ "std": [
617
+ 0.12440319799919354,
618
+ 0.11571359399631491,
619
+ 0.2458943611771509,
620
+ 0.5132342578001884,
621
+ 0.5223439094545202,
622
+ 0.1666598633276366,
623
+ 0.27617123901287927,
624
+ 0.4538753441706389
625
+ ],
626
+ "count": [
627
+ 3786400.0
628
+ ],
629
+ "q01": [
630
+ 0.3249422830693862,
631
+ -0.28341992821874495,
632
+ 0.14102827969076331,
633
+ -0.6864852132802142,
634
+ -0.6809632829655476,
635
+ -0.36044700054021983,
636
+ -0.4542378536110671,
637
+ -1.000000013351432e-10
638
+ ],
639
+ "q10": [
640
+ 0.42490653590113253,
641
+ -0.2163404740670024,
642
+ 0.37762326560147996,
643
+ -0.6294334687684712,
644
+ -0.5920843577131312,
645
+ -0.09803071723264807,
646
+ -0.23202098126670248,
647
+ -1.000000013351432e-10
648
+ ],
649
+ "q50": [
650
+ 0.5389458633818717,
651
+ -0.10059445446247807,
652
+ 0.8738477700690715,
653
+ -0.4849259061727551,
654
+ 0.7293306254210121,
655
+ 0.09137287071030761,
656
+ 0.23796976550241536,
657
+ 0.1832136750707287
658
+ ],
659
+ "q90": [
660
+ 0.7370583820538442,
661
+ 0.08210784119164745,
662
+ 0.9798527660285249,
663
+ 0.7291734785677116,
664
+ 0.84104651841686,
665
+ 0.3032210107222038,
666
+ 0.5373912158511455,
667
+ 0.9999365636178622
668
+ ],
669
+ "q99": [
670
+ 0.8750117781915163,
671
+ 0.21252014598261149,
672
+ 1.0727446933587392,
673
+ 0.9378297494636977,
674
+ 0.9562844548524763,
675
+ 0.46002622460251424,
676
+ 0.721691133425786,
677
+ 0.9999936563617862
678
+ ],
679
+ "names": [
680
+ "x",
681
+ "y",
682
+ "z",
683
+ "rx",
684
+ "ry",
685
+ "rz",
686
+ "rw",
687
+ "gripper"
688
+ ],
689
+ "mask": [
690
+ true,
691
+ true,
692
+ true,
693
+ true,
694
+ true,
695
+ true,
696
+ true,
697
+ false
698
+ ]
699
+ }
700
+ },
701
+ "widowx_bridge": {
702
+ "action_key": "action",
703
+ "state_key": "observation.state",
704
+ "camera_keys": [
705
+ "observation.images.image_0",
706
+ "observation.images.image_1",
707
+ "observation.images.image_2",
708
+ "observation.images.image_3"
709
+ ],
710
+ "normalize_gripper": false,
711
+ "action_horizon": 5,
712
+ "n_action_steps": 5,
713
+ "setup_type": "single widowx robotic arm in bridge",
714
+ "control_mode": "delta end-effector pose",
715
+ "action_stats": {
716
+ "min": [
717
+ -0.4007510244846344,
718
+ -0.13874775171279907,
719
+ -0.22553899884223938,
720
+ -3.2010786533355713,
721
+ -1.8618112802505493,
722
+ -6.279075622558594,
723
+ 0.0
724
+ ],
725
+ "max": [
726
+ 0.41691166162490845,
727
+ 0.25864794850349426,
728
+ 0.21218234300613403,
729
+ 3.122201919555664,
730
+ 1.8618112802505493,
731
+ 6.272472858428955,
732
+ 1.0
733
+ ],
734
+ "mean": [
735
+ 0.00022731789976267202,
736
+ 0.0001311203695138562,
737
+ -0.00012641641264803482,
738
+ -0.00014410962647987843,
739
+ -0.0003903070519037156,
740
+ 0.00024063480455490454,
741
+ 0.5765894392570026
742
+ ],
743
+ "std": [
744
+ 0.009782343005332487,
745
+ 0.013714070718580267,
746
+ 0.012687395519404626,
747
+ 0.02848996416069207,
748
+ 0.030552792886390234,
749
+ 0.07751153262919225,
750
+ 0.49409209255711634
751
+ ],
752
+ "count": [
753
+ 1893026.0
754
+ ],
755
+ "q01": [
756
+ -0.02871995611488819,
757
+ -0.04170781908448411,
758
+ -0.02608340910386921,
759
+ -0.0808367313719228,
760
+ -0.09246813206247581,
761
+ -0.20693750972396757,
762
+ -1.000000013351432e-10
763
+ ],
764
+ "q10": [
765
+ -0.010151055597043716,
766
+ -0.014922217821287087,
767
+ -0.01393665282931255,
768
+ -0.029593090264604636,
769
+ -0.03406380769665256,
770
+ -0.06413116391050117,
771
+ -1.000000013351432e-10
772
+ ],
773
+ "q50": [
774
+ 2.1248139354103056e-05,
775
+ -9.382913823534339e-06,
776
+ -0.0008275577521357758,
777
+ -0.00014731252460737677,
778
+ 0.00047152176188271845,
779
+ 0.0012537133528066303,
780
+ 0.9998265319000765
781
+ ],
782
+ "q90": [
783
+ 0.011082387395765428,
784
+ 0.015737555353724994,
785
+ 0.016874204550636374,
786
+ 0.02832893788750676,
787
+ 0.0322629905973504,
788
+ 0.06417266804375155,
789
+ 0.9999653063800154
790
+ ],
791
+ "q99": [
792
+ 0.028291364035668985,
793
+ 0.040898679036702676,
794
+ 0.04018220331768194,
795
+ 0.08177042032653538,
796
+ 0.07759675528459531,
797
+ 0.203201938256362,
798
+ 0.9999965306380015
799
+ ],
800
+ "names": [
801
+ "x",
802
+ "y",
803
+ "z",
804
+ "roll",
805
+ "pitch",
806
+ "yaw",
807
+ "gripper"
808
+ ],
809
+ "mask": [
810
+ true,
811
+ true,
812
+ true,
813
+ true,
814
+ true,
815
+ true,
816
+ false
817
+ ]
818
+ },
819
+ "state_stats": {
820
+ "min": [
821
+ -0.04167502000927925,
822
+ -0.3563207685947418,
823
+ -0.15537554025650024,
824
+ -3.141592502593994,
825
+ -1.4992541074752808,
826
+ -3.14153790473938,
827
+ 0.0,
828
+ 0.04637829214334488
829
+ ],
830
+ "max": [
831
+ 0.5862360596656799,
832
+ 0.4034728705883026,
833
+ 0.3568263053894043,
834
+ 1.3517684936523438,
835
+ 1.570796251296997,
836
+ 3.141204357147217,
837
+ 0.0,
838
+ 1.1121242046356201
839
+ ],
840
+ "mean": [
841
+ 0.3094503633235095,
842
+ 0.030725376723448255,
843
+ 0.06443996750169499,
844
+ 0.0064906683342908335,
845
+ -0.07720050195254197,
846
+ 0.10766038148835028,
847
+ 0.0,
848
+ 0.7081244810708762
849
+ ],
850
+ "std": [
851
+ 0.06060302901710459,
852
+ 0.0919536927343182,
853
+ 0.05159382707079282,
854
+ 0.1312174751351825,
855
+ 0.16924010047039229,
856
+ 0.5787203550709503,
857
+ 0.0,
858
+ 0.35365012001260804
859
+ ],
860
+ "count": [
861
+ 1893026.0
862
+ ],
863
+ "q01": [
864
+ 0.17102651970064053,
865
+ -0.16977934478310977,
866
+ -0.05565095783375642,
867
+ -0.3649685841887744,
868
+ -0.5418705685890239,
869
+ -1.3540046312592247,
870
+ 0.0,
871
+ 0.05212163980268402
872
+ ],
873
+ "q10": [
874
+ 0.234054275333357,
875
+ -0.08584102855009192,
876
+ 0.007129108058706664,
877
+ -0.13279207613930774,
878
+ -0.2879179685802783,
879
+ -0.47590377710082316,
880
+ 0.0,
881
+ 0.08160105384386226
882
+ ],
883
+ "q50": [
884
+ 0.30824996150509265,
885
+ 0.02806205006373531,
886
+ 0.061364141277506515,
887
+ 0.003477529234181987,
888
+ -0.06586482997881163,
889
+ 0.033681061760553146,
890
+ 0.0,
891
+ 0.9850432498405283
892
+ ],
893
+ "q90": [
894
+ 0.3866535994382209,
895
+ 0.15225549791502352,
896
+ 0.1303319111924363,
897
+ 0.14920492884702988,
898
+ 0.11511126950562722,
899
+ 0.8206040455663128,
900
+ 0.0,
901
+ 1.0013512433353218
902
+ ],
903
+ "q99": [
904
+ 0.453255677819252,
905
+ 0.23543677111215228,
906
+ 0.19489739182202712,
907
+ 0.378015822982788,
908
+ 0.27597790842706504,
909
+ 1.8504199743270873,
910
+ 0.0,
911
+ 1.0106366157291133
912
+ ],
913
+ "names": [
914
+ "x",
915
+ "y",
916
+ "z",
917
+ "roll",
918
+ "pitch",
919
+ "yaw",
920
+ "pad",
921
+ "gripper"
922
+ ],
923
+ "mask": [
924
+ true,
925
+ true,
926
+ true,
927
+ true,
928
+ true,
929
+ true,
930
+ true,
931
+ false
932
+ ]
933
+ }
934
+ },
935
+ "so100_so101_molmoact2": {
936
+ "action_key": "action",
937
+ "state_key": "observation.state",
938
+ "camera_keys": [],
939
+ "normalize_gripper": true,
940
+ "action_horizon": 30,
941
+ "n_action_steps": 30,
942
+ "setup_type": "single so100/so101 robotic arm in molmoact2",
943
+ "control_mode": "absolute joint pose",
944
+ "action_stats": {
945
+ "min": [
946
+ -122.607421875,
947
+ -270.0,
948
+ -269.208984375,
949
+ -125.771484375,
950
+ -269.912109375,
951
+ -31.57327651977539
952
+ ],
953
+ "max": [
954
+ 179.208984375,
955
+ 219.638671875,
956
+ 195.380859375,
957
+ 178.9453125,
958
+ 269.82421875,
959
+ 119.40789031982422
960
+ ],
961
+ "mean": [
962
+ 3.343996486826433,
963
+ 125.7905980370996,
964
+ 120.20220128113388,
965
+ 55.88144220174933,
966
+ -11.543010633027725,
967
+ 11.25886240824774
968
+ ],
969
+ "std": [
970
+ 28.909870406169997,
971
+ 52.25069634659296,
972
+ 47.94432906599221,
973
+ 36.01019142727721,
974
+ 69.35504013212369,
975
+ 17.116239869449775
976
+ ],
977
+ "count": [
978
+ 19619650.0
979
+ ],
980
+ "q01": [
981
+ -42.1300246338976,
982
+ 45.18258358164995,
983
+ 35.40059182962813,
984
+ 4.929781836327758,
985
+ -65.57568617645342,
986
+ -0.3016556932619033
987
+ ],
988
+ "q10": [
989
+ -25.040070398997557,
990
+ 68.27827215165794,
991
+ 65.76540485606242,
992
+ 26.58811186925123,
993
+ -39.81868441470048,
994
+ 0.26123181871944706
995
+ ],
996
+ "q50": [
997
+ 3.0828094324713105,
998
+ 124.5495736487354,
999
+ 122.75175717637279,
1000
+ 57.77960070056314,
1001
+ -11.094802886190045,
1002
+ 4.866634607477139
1003
+ ],
1004
+ "q90": [
1005
+ 31.591544866079253,
1006
+ 181.76986724267596,
1007
+ 168.5741215400282,
1008
+ 82.4353358815596,
1009
+ 16.05609349144359,
1010
+ 32.12324970648343
1011
+ ],
1012
+ "q99": [
1013
+ 48.55349563198916,
1014
+ 186.10646680077767,
1015
+ 173.6076722013997,
1016
+ 93.41056417929472,
1017
+ 43.53107398260694,
1018
+ 44.74649336930881
1019
+ ],
1020
+ "names": [
1021
+ "shoulder_pan",
1022
+ "shoulder_lift",
1023
+ "elbow_flex",
1024
+ "wrist_flex",
1025
+ "wrist_roll",
1026
+ "gripper"
1027
+ ],
1028
+ "mask": [
1029
+ true,
1030
+ true,
1031
+ true,
1032
+ true,
1033
+ true,
1034
+ true
1035
+ ]
1036
+ },
1037
+ "state_stats": {
1038
+ "min": [
1039
+ -115.048828125,
1040
+ -270.0,
1041
+ -235.8984375,
1042
+ -113.818359375,
1043
+ -268.9453125,
1044
+ -8.521058082580566
1045
+ ],
1046
+ "max": [
1047
+ 178.505859375,
1048
+ 218.49609375,
1049
+ 192.041015625,
1050
+ 207.861328125,
1051
+ 250.048828125,
1052
+ 118.2519302368164
1053
+ ],
1054
+ "mean": [
1055
+ 3.3225097946752244,
1056
+ 124.40594064960378,
1057
+ 121.59550610749059,
1058
+ 55.903039878016074,
1059
+ -11.41740021122887,
1060
+ 13.358497334686597
1061
+ ],
1062
+ "std": [
1063
+ 28.79265204113751,
1064
+ 52.702867303079756,
1065
+ 47.00596021941705,
1066
+ 35.53803566355756,
1067
+ 69.12836626047817,
1068
+ 16.333280282904557
1069
+ ],
1070
+ "count": [
1071
+ 19619650.0
1072
+ ],
1073
+ "q01": [
1074
+ -41.90962240941357,
1075
+ 43.66791235922949,
1076
+ 38.38770483255723,
1077
+ 5.711740446834044,
1078
+ -63.44539045209019,
1079
+ 0.9435577790191543
1080
+ ],
1081
+ "q10": [
1082
+ -24.949315993050774,
1083
+ 66.30007546431412,
1084
+ 68.16816985859437,
1085
+ 27.120731646136054,
1086
+ -39.50255020332888,
1087
+ 1.6190225837869365
1088
+ ],
1089
+ "q50": [
1090
+ 3.066375725640164,
1091
+ 123.16482094240277,
1092
+ 124.39930058290133,
1093
+ 57.88605464633133,
1094
+ -11.037436711677765,
1095
+ 9.241478261568748
1096
+ ],
1097
+ "q90": [
1098
+ 31.472920732960127,
1099
+ 180.87158401301218,
1100
+ 168.5699720215359,
1101
+ 81.64709150074712,
1102
+ 15.887605114617852,
1103
+ 31.887861734718296
1104
+ ],
1105
+ "q99": [
1106
+ 48.29435703371732,
1107
+ 185.2611055842669,
1108
+ 173.13578487933165,
1109
+ 91.78122415137209,
1110
+ 42.94491979114059,
1111
+ 44.13755601580974
1112
+ ],
1113
+ "names": [
1114
+ "shoulder_pan",
1115
+ "shoulder_lift",
1116
+ "elbow_flex",
1117
+ "wrist_flex",
1118
+ "wrist_roll",
1119
+ "gripper"
1120
+ ],
1121
+ "mask": [
1122
+ true,
1123
+ true,
1124
+ true,
1125
+ true,
1126
+ true,
1127
+ true
1128
+ ]
1129
+ }
1130
+ },
1131
+ "google_robot_bc_z": {
1132
+ "action_key": "action",
1133
+ "state_key": "observation.state",
1134
+ "camera_keys": [
1135
+ "observation.images.image"
1136
+ ],
1137
+ "normalize_gripper": false,
1138
+ "action_horizon": 10,
1139
+ "n_action_steps": 10,
1140
+ "setup_type": "google robot in bc_z",
1141
+ "control_mode": "delta end-effector pose",
1142
+ "action_stats": {
1143
+ "min": [
1144
+ -0.1677047461271286,
1145
+ -0.14630407094955444,
1146
+ -0.10066790133714676,
1147
+ -0.29421567916870117,
1148
+ -0.32101404666900635,
1149
+ -0.4635624885559082,
1150
+ 0.0
1151
+ ],
1152
+ "max": [
1153
+ 0.2165454924106598,
1154
+ 0.1251407265663147,
1155
+ 0.09988310933113098,
1156
+ 0.33544227480888367,
1157
+ 0.28117990493774414,
1158
+ 0.40614867210388184,
1159
+ 1.0
1160
+ ],
1161
+ "mean": [
1162
+ -0.009960200864471745,
1163
+ 0.0009084977087131892,
1164
+ 0.00499393515302369,
1165
+ 0.00028739003438370427,
1166
+ -0.00871610909893306,
1167
+ -0.030692461306736755,
1168
+ 0.8343520005664466
1169
+ ],
1170
+ "std": [
1171
+ 0.03080177058689462,
1172
+ 0.023236620172139833,
1173
+ 0.020777592916798007,
1174
+ 0.041763587623031895,
1175
+ 0.046686683400427,
1176
+ 0.07753463216688747,
1177
+ 0.3717643553432202
1178
+ ],
1179
+ "count": [
1180
+ 5471693.0
1181
+ ],
1182
+ "q01": [
1183
+ -0.09213472068957661,
1184
+ -0.06450906318665113,
1185
+ -0.04912072456744037,
1186
+ -0.11609895664024446,
1187
+ -0.1413486404610977,
1188
+ -0.22517701597416145,
1189
+ -1.000000013351432e-10
1190
+ ],
1191
+ "q10": [
1192
+ -0.05253115985050928,
1193
+ -0.028533985817234882,
1194
+ -0.021736428190829056,
1195
+ -0.04809403695382897,
1196
+ -0.0664864549799673,
1197
+ -0.1391167833364122,
1198
+ -1.000000013351432e-10
1199
+ ],
1200
+ "q50": [
1201
+ -0.0031453596109414592,
1202
+ 0.0004054125482836473,
1203
+ 0.0023481391860319715,
1204
+ -8.489440239357886e-05,
1205
+ -0.002574837787014793,
1206
+ -0.014108526356650069,
1207
+ 0.9998801266205536
1208
+ ],
1209
+ "q90": [
1210
+ 0.019494707527676427,
1211
+ 0.029460992205482695,
1212
+ 0.032557826189659966,
1213
+ 0.04931595102291217,
1214
+ 0.042994841552155126,
1215
+ 0.05302803170853769,
1216
+ 0.9999760253241107
1217
+ ],
1218
+ "q99": [
1219
+ 0.07630278211772451,
1220
+ 0.05802308552485688,
1221
+ 0.052553275338456634,
1222
+ 0.1173714221625478,
1223
+ 0.11711249897425843,
1224
+ 0.1673988100025391,
1225
+ 0.9999976025324111
1226
+ ],
1227
+ "names": [
1228
+ "x",
1229
+ "y",
1230
+ "z",
1231
+ "roll",
1232
+ "pitch",
1233
+ "yaw",
1234
+ "gripper"
1235
+ ],
1236
+ "mask": [
1237
+ true,
1238
+ true,
1239
+ true,
1240
+ true,
1241
+ true,
1242
+ true,
1243
+ false
1244
+ ]
1245
+ },
1246
+ "state_stats": {
1247
+ "min": [
1248
+ -0.7190948724746704,
1249
+ -0.3756217360496521,
1250
+ -0.281008243560791,
1251
+ -2.400146484375,
1252
+ -2.500656843185425,
1253
+ -3.1274476051330566,
1254
+ 0.0,
1255
+ 0.0
1256
+ ],
1257
+ "max": [
1258
+ 0.6597589254379272,
1259
+ 0.7259413599967957,
1260
+ 1.1217665672302246,
1261
+ 2.2803165912628174,
1262
+ 1.8151572942733765,
1263
+ 3.1237573623657227,
1264
+ 0.0,
1265
+ 1.0
1266
+ ],
1267
+ "mean": [
1268
+ 0.0176884768449917,
1269
+ 0.10948195169606133,
1270
+ 0.784290845584472,
1271
+ -0.5290053991424425,
1272
+ -0.22605912165135514,
1273
+ -0.17858785012278866,
1274
+ 0.0,
1275
+ 0.5600556496096702
1276
+ ],
1277
+ "std": [
1278
+ 0.1841601172406892,
1279
+ 0.09627411033983578,
1280
+ 0.08699189118288073,
1281
+ 0.24700645691257475,
1282
+ 0.4286554852012691,
1283
+ 1.0001615516228195,
1284
+ 0.0,
1285
+ 0.3586031013748201
1286
+ ],
1287
+ "count": [
1288
+ 5471693.0
1289
+ ],
1290
+ "q01": [
1291
+ -0.38789819221198557,
1292
+ -0.1118956928319213,
1293
+ 0.6110697470322705,
1294
+ -1.0415028765133625,
1295
+ -1.1876200204022105,
1296
+ -2.3808376895782,
1297
+ 0.0,
1298
+ 0.19986777120588917
1299
+ ],
1300
+ "q10": [
1301
+ -0.2318964688694949,
1302
+ -0.015558046064633315,
1303
+ 0.6822309043992328,
1304
+ -0.7563316012340816,
1305
+ -0.7533119325741587,
1306
+ -1.3938289285869132,
1307
+ 0.0,
1308
+ 0.2000496453831541
1309
+ ],
1310
+ "q50": [
1311
+ 0.022859327635303822,
1312
+ 0.10637610222856157,
1313
+ 0.776611691927557,
1314
+ -0.5671171062825059,
1315
+ -0.24114911945667813,
1316
+ -0.25162686787881255,
1317
+ 0.0,
1318
+ 0.3501994619818789
1319
+ ],
1320
+ "q90": [
1321
+ 0.2666238156546802,
1322
+ 0.23844897018337458,
1323
+ 0.9059002565684082,
1324
+ -0.26983885858517637,
1325
+ 0.3994129877275485,
1326
+ 1.374448904817122,
1327
+ 0.0,
1328
+ 0.999900866490248
1329
+ ],
1330
+ "q99": [
1331
+ 0.3325375374171561,
1332
+ 0.31715197447407467,
1333
+ 0.982179447052214,
1334
+ 0.34632693633800826,
1335
+ 0.7713777675821983,
1336
+ 2.029990628516839,
1337
+ 0.0,
1338
+ 0.9999900866490248
1339
+ ],
1340
+ "names": [
1341
+ "x",
1342
+ "y",
1343
+ "z",
1344
+ "roll",
1345
+ "pitch",
1346
+ "yaw",
1347
+ "pad",
1348
+ "gripper"
1349
+ ],
1350
+ "mask": [
1351
+ true,
1352
+ true,
1353
+ true,
1354
+ true,
1355
+ true,
1356
+ true,
1357
+ true,
1358
+ false
1359
+ ]
1360
+ }
1361
+ },
1362
+ "yam_dual_molmoact2": {
1363
+ "action_key": "action",
1364
+ "state_key": "observation.state",
1365
+ "camera_keys": [
1366
+ "observation.images.top",
1367
+ "observation.images.left",
1368
+ "observation.images.right"
1369
+ ],
1370
+ "normalize_gripper": false,
1371
+ "action_horizon": 30,
1372
+ "n_action_steps": 30,
1373
+ "setup_type": "bimanual yam robotic arms in molmoact2",
1374
+ "control_mode": "absolute joint pose",
1375
+ "action_stats": {
1376
+ "min": [
1377
+ -1.9876782894134521,
1378
+ -0.007057297509163618,
1379
+ -0.002861066721379757,
1380
+ -1.6958495378494263,
1381
+ -1.5730143785476685,
1382
+ -2.184138298034668,
1383
+ 0.0,
1384
+ -1.6771572828292847,
1385
+ -0.00667582219466567,
1386
+ -0.0032425422687083483,
1387
+ -1.7061493396759033,
1388
+ -1.6287097930908203,
1389
+ -2.143320322036743,
1390
+ 0.0
1391
+ ],
1392
+ "max": [
1393
+ 1.808003306388855,
1394
+ 3.1988632678985596,
1395
+ 3.1507973670959473,
1396
+ 1.592851161956787,
1397
+ 1.5890363454818726,
1398
+ 2.2081711292266846,
1399
+ 1.0,
1400
+ 2.440871238708496,
1401
+ 3.1084535121917725,
1402
+ 3.1530861854553223,
1403
+ 1.6649500131607056,
1404
+ 1.5947585105895996,
1405
+ 2.1639199256896973,
1406
+ 1.0
1407
+ ],
1408
+ "mean": [
1409
+ -0.08857854148141169,
1410
+ 1.3813960226201991,
1411
+ 1.2242081192216245,
1412
+ -0.7456114034786908,
1413
+ 0.15342910390834139,
1414
+ -0.2406550926649683,
1415
+ 0.6405881969404109,
1416
+ 0.11816370494944337,
1417
+ 1.3440412881232742,
1418
+ 1.1275448419933234,
1419
+ -0.6567647967296087,
1420
+ -0.15745777770921981,
1421
+ 0.20879381691599022,
1422
+ 0.5971762495146153
1423
+ ],
1424
+ "std": [
1425
+ 0.31549225693975164,
1426
+ 0.7241109409894698,
1427
+ 0.6724976443740277,
1428
+ 0.4912531895036823,
1429
+ 0.3766601597067631,
1430
+ 0.3683009171682207,
1431
+ 0.41042883365599214,
1432
+ 0.33538355728349317,
1433
+ 0.8035033283123882,
1434
+ 0.7129305114483252,
1435
+ 0.5147389512393373,
1436
+ 0.37362261558635523,
1437
+ 0.35878804842243267,
1438
+ 0.42346789755808983
1439
+ ],
1440
+ "count": [
1441
+ 76046658.0
1442
+ ],
1443
+ "q01": [
1444
+ -0.6603105582072047,
1445
+ 0.0041340051935240115,
1446
+ 0.013831665477596221,
1447
+ -1.3744044717113109,
1448
+ -0.3593570239425977,
1449
+ -0.9302641712677729,
1450
+ 0.051016362361406005,
1451
+ -0.49367228465810536,
1452
+ 0.004744360313868616,
1453
+ 0.017154297804418434,
1454
+ -1.4240273823045295,
1455
+ -0.9737084779331572,
1456
+ -0.4719268433374943,
1457
+ 0.033350514024370274
1458
+ ],
1459
+ "q10": [
1460
+ -0.4158939180171844,
1461
+ 0.49040349295087926,
1462
+ 0.48318427047331663,
1463
+ -1.1595704371830307,
1464
+ -0.13299944787425266,
1465
+ -0.5670792130135129,
1466
+ 0.11117863560492024,
1467
+ -0.19067792775434206,
1468
+ 0.19335683280594596,
1469
+ 0.1783492294932824,
1470
+ -1.165289828212844,
1471
+ -0.5363078842413471,
1472
+ -0.11410713925580458,
1473
+ 0.054251135868839034
1474
+ ],
1475
+ "q50": [
1476
+ -0.07347940057883112,
1477
+ 1.4486934996424023,
1478
+ 1.2826819985862519,
1479
+ -0.8018464396181274,
1480
+ 0.11333067563787286,
1481
+ -0.22188306769880142,
1482
+ 0.7333514901431821,
1483
+ 0.08159376899519756,
1484
+ 1.542016049355695,
1485
+ 1.2518141457542857,
1486
+ -0.6816567194944295,
1487
+ -0.12921257250905716,
1488
+ 0.19217648232095094,
1489
+ 0.6965966006454063
1490
+ ],
1491
+ "q90": [
1492
+ 0.21224325405051755,
1493
+ 2.0044457220962184,
1494
+ 1.7599272535504926,
1495
+ -0.17992348512991949,
1496
+ 0.5121005560866031,
1497
+ 0.06588770556098025,
1498
+ 0.9798257827982823,
1499
+ 0.49762827627115913,
1500
+ 2.062871328579572,
1501
+ 1.7914606668876476,
1502
+ -0.07308204053490945,
1503
+ 0.182291301998786,
1504
+ 0.5569780500008801,
1505
+ 0.9922195168313757
1506
+ ],
1507
+ "q99": [
1508
+ 0.4704245731743921,
1509
+ 2.244327078820327,
1510
+ 2.0080105207169177,
1511
+ 0.13399061379118773,
1512
+ 0.8834156417282395,
1513
+ 0.334483290041328,
1514
+ 0.987078674113364,
1515
+ 0.7377501348730936,
1516
+ 2.285076596429336,
1517
+ 2.0605540868103542,
1518
+ 0.23968854170206916,
1519
+ 0.5304791687465945,
1520
+ 0.9621494841801348,
1521
+ 0.9953596816858612
1522
+ ],
1523
+ "names": [
1524
+ "left_joint_0.pos",
1525
+ "left_joint_1.pos",
1526
+ "left_joint_2.pos",
1527
+ "left_joint_3.pos",
1528
+ "left_joint_4.pos",
1529
+ "left_joint_5.pos",
1530
+ "left_gripper.pos",
1531
+ "right_joint_0.pos",
1532
+ "right_joint_1.pos",
1533
+ "right_joint_2.pos",
1534
+ "right_joint_3.pos",
1535
+ "right_joint_4.pos",
1536
+ "right_joint_5.pos",
1537
+ "right_gripper.pos"
1538
+ ],
1539
+ "mask": [
1540
+ true,
1541
+ true,
1542
+ true,
1543
+ true,
1544
+ true,
1545
+ true,
1546
+ false,
1547
+ true,
1548
+ true,
1549
+ true,
1550
+ true,
1551
+ true,
1552
+ true,
1553
+ false
1554
+ ]
1555
+ },
1556
+ "state_stats": {
1557
+ "min": [
1558
+ -1.971656322479248,
1559
+ 0.00019073777366429567,
1560
+ 0.001716639962978661,
1561
+ -1.7023346424102783,
1562
+ -1.576829195022583,
1563
+ -2.0963988304138184,
1564
+ 0.0005250918911769986,
1565
+ -1.6741054058074951,
1566
+ -0.0009536888683214784,
1567
+ 0.004386968910694122,
1568
+ -1.737811803817749,
1569
+ -1.574158787727356,
1570
+ -2.0941100120544434,
1571
+ 0.003973988350480795
1572
+ ],
1573
+ "max": [
1574
+ 1.813725471496582,
1575
+ 3.101205348968506,
1576
+ 3.1466009616851807,
1577
+ 1.5821698904037476,
1578
+ 1.6222248077392578,
1579
+ 2.1040284633636475,
1580
+ 0.9997128844261169,
1581
+ 2.4343862533569336,
1582
+ 3.11112380027771,
1583
+ 3.1492714881896973,
1584
+ 1.5836957693099976,
1585
+ 1.6062028408050537,
1586
+ 2.1452276706695557,
1587
+ 1.0
1588
+ ],
1589
+ "mean": [
1590
+ -0.08969431138176573,
1591
+ 1.3833397954729871,
1592
+ 1.2214299123909826,
1593
+ -0.7438162535789633,
1594
+ 0.15467924320885904,
1595
+ -0.2444551331990551,
1596
+ 0.6477599794157677,
1597
+ 0.11772745375836342,
1598
+ 1.3475698442605246,
1599
+ 1.1241839262647857,
1600
+ -0.657754523106273,
1601
+ -0.16024992695882134,
1602
+ 0.2095172679704065,
1603
+ 0.6019240399143698
1604
+ ],
1605
+ "std": [
1606
+ 0.3152726802877428,
1607
+ 0.7215555774539155,
1608
+ 0.6677525379386945,
1609
+ 0.49249044506684236,
1610
+ 0.3669531426180722,
1611
+ 0.36500773276171394,
1612
+ 0.4034043094483581,
1613
+ 0.3350780291739786,
1614
+ 0.8015514140140498,
1615
+ 0.7087483761552382,
1616
+ 0.5140769455948587,
1617
+ 0.36485948060191936,
1618
+ 0.35558886385685473,
1619
+ 0.4187505380995499
1620
+ ],
1621
+ "count": [
1622
+ 76046658.0
1623
+ ],
1624
+ "q01": [
1625
+ -0.6603467782218314,
1626
+ 0.012553692652370085,
1627
+ 0.021776265158983142,
1628
+ -1.3705572057237516,
1629
+ -0.3332034826366618,
1630
+ -0.9193192400336088,
1631
+ 0.059239047676073166,
1632
+ -0.4935656974138795,
1633
+ 0.012780929401173773,
1634
+ 0.022236669213863816,
1635
+ -1.4227596196972356,
1636
+ -0.9434528867624581,
1637
+ -0.4598343195103144,
1638
+ 0.037835498581155064
1639
+ ],
1640
+ "q10": [
1641
+ -0.41642163282166217,
1642
+ 0.49507907198249584,
1643
+ 0.486584320872561,
1644
+ -1.1582997707602973,
1645
+ -0.12275828541607876,
1646
+ -0.5663963402767317,
1647
+ 0.1261316463154828,
1648
+ -0.1908506486628405,
1649
+ 0.1993559996076043,
1650
+ 0.18204643795012038,
1651
+ -1.1656159852054215,
1652
+ -0.5295295866303873,
1653
+ -0.10955673634265617,
1654
+ 0.06449180996120647
1655
+ ],
1656
+ "q50": [
1657
+ -0.07460956719060403,
1658
+ 1.4518741988602484,
1659
+ 1.2790339607814287,
1660
+ -0.8004009901188069,
1661
+ 0.11633919425925929,
1662
+ -0.2256239564587,
1663
+ 0.7410515739838786,
1664
+ 0.08125296212737787,
1665
+ 1.546374492933441,
1666
+ 1.2473645258976782,
1667
+ -0.6826830989658852,
1668
+ -0.13268823647237576,
1669
+ 0.19324335769817771,
1670
+ 0.6975293719700979
1671
+ ],
1672
+ "q90": [
1673
+ 0.21065792154366084,
1674
+ 2.001783199519663,
1675
+ 1.7536904322237028,
1676
+ -0.17570477043577734,
1677
+ 0.5016373270395832,
1678
+ 0.057081945863381375,
1679
+ 0.9793483311612012,
1680
+ 0.496661138954089,
1681
+ 2.0633422575822404,
1682
+ 1.784104252873167,
1683
+ -0.07449674242785952,
1684
+ 0.17045548433242785,
1685
+ 0.5532139533377123,
1686
+ 0.9916430884848699
1687
+ ],
1688
+ "q99": [
1689
+ 0.4683004661020414,
1690
+ 2.2309715341843326,
1691
+ 1.9982285068319416,
1692
+ 0.13319204881075056,
1693
+ 0.8574646079271142,
1694
+ 0.31881311685642116,
1695
+ 0.9862640952345495,
1696
+ 0.736253091937041,
1697
+ 2.276675221510269,
1698
+ 2.0496951704229227,
1699
+ 0.23446313153252643,
1700
+ 0.503194049828884,
1701
+ 0.9489437100128476,
1702
+ 0.9945109907992316
1703
+ ],
1704
+ "names": [
1705
+ "left_joint_0.pos",
1706
+ "left_joint_1.pos",
1707
+ "left_joint_2.pos",
1708
+ "left_joint_3.pos",
1709
+ "left_joint_4.pos",
1710
+ "left_joint_5.pos",
1711
+ "left_gripper.pos",
1712
+ "right_joint_0.pos",
1713
+ "right_joint_1.pos",
1714
+ "right_joint_2.pos",
1715
+ "right_joint_3.pos",
1716
+ "right_joint_4.pos",
1717
+ "right_joint_5.pos",
1718
+ "right_gripper.pos"
1719
+ ],
1720
+ "mask": [
1721
+ true,
1722
+ true,
1723
+ true,
1724
+ true,
1725
+ true,
1726
+ true,
1727
+ false,
1728
+ true,
1729
+ true,
1730
+ true,
1731
+ true,
1732
+ true,
1733
+ true,
1734
+ false
1735
+ ]
1736
+ }
1737
+ }
1738
+ }
1739
+ }
processing_molmoact2.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processor class for MolmoAct2.
3
+ """
4
+ from typing import Optional, Union
5
+ import dataclasses
6
+
7
+ import numpy as np
8
+
9
+ from transformers.image_utils import ImageInput
10
+ from transformers.video_utils import VideoInput
11
+ from transformers.processing_utils import (
12
+ Unpack,
13
+ ProcessingKwargs,
14
+ ProcessorMixin,
15
+ )
16
+ from transformers.feature_extraction_utils import BatchFeature
17
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
18
+ from transformers.utils import logging
19
+
20
+ from transformers import AutoTokenizer
21
+ from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
22
+ from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them.
# Plain string literals: the previous f-strings had no placeholders (ruff F541).
IMAGE_PATCH_TOKEN = "<im_patch>"  # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = "<im_low>"  # Where to insert low-res tokens
IM_START_TOKEN = "<im_start>"
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>"
FRAME_START_TOKEN = "<frame_start>"
IM_END_TOKEN = "<im_end>"
FRAME_END_TOKEN = "<frame_end>"
IM_COL_TOKEN = "<im_col>"
# Placeholders the user writes in prompts; expanded by the processor.
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"

# All image-structure special tokens; used to build `token_type_ids` marking
# which positions of `input_ids` belong to an image.
IMAGE_TOKENS = [
    IMAGE_PATCH_TOKEN,
    IM_COL_TOKEN,
    IM_START_TOKEN,
    LOW_RES_IMAGE_START_TOKEN,
    FRAME_START_TOKEN,
    IM_END_TOKEN,
    FRAME_END_TOKEN,
    IMAGE_LOW_RES_TOKEN,
]
50
+
51
+
52
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
    """MolmoAct2 processor kwargs: typed kwarg groups accepted by `MolmoAct2Processor.__call__`."""
    # Kwargs forwarded to the image / video sub-processors.
    images_kwargs: MolmoAct2ImagesKwargs
    videos_kwargs: MolmoAct2VideoProcessorKwargs
    # Defaults merged in by ProcessorMixin._merge_kwargs unless the caller overrides them.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": True,
        },
        # Video metadata (timestamps) is required internally to build the
        # per-frame time prefixes; popped before returning unless requested.
        "videos_kwargs": {"return_metadata": True},
    }
63
+
64
+
65
class MolmoAct2Processor(ProcessorMixin):
    """
    Wraps a MolmoAct2 image processor, video processor and tokenizer into one
    multimodal processor.

    Prompts contain `<|image|>` / `<|video|>` placeholders; `__call__` expands each
    placeholder into the grid of image special tokens matching the processed
    crops/frames, tokenizes the result, and prepends a BOS token if missing.
    """

    # Sub-processors managed/saved by ProcessorMixin.
    attributes = ["image_processor", "video_processor", "tokenizer"]
    # Extra configuration persisted with the processor config.
    optional_attributes = [
        "chat_template",
        "time_mode",
        "image_use_col_tokens",
        "use_single_crop_col_tokens",
        "use_single_crop_start_token",
        "video_use_col_tokens",
        "use_frame_special_tokens",
    ]
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor: MolmoAct2ImageProcessor = None,
        video_processor: MolmoAct2VideoProcessor = None,
        tokenizer: AutoTokenizer = None,
        chat_template: Optional[str] = None,
        image_use_col_tokens: Optional[bool] = True,
        use_single_crop_col_tokens: Optional[bool] = None,
        use_single_crop_start_token: Optional[bool] = True,
        video_use_col_tokens: Optional[bool] = False,
        use_frame_special_tokens: Optional[bool] = True,
        **kwargs
    ) -> None:
        """
        Args:
            image_processor: produces image crops and `image_grids`.
            video_processor: produces video frames and `video_grids`.
            tokenizer: must contain every token in `IMAGE_TOKENS`.
            chat_template: optional chat template forwarded to ProcessorMixin.
            image_use_col_tokens: append `<im_col>` after each row of high-res
                image patch tokens.
            use_single_crop_col_tokens: same, for the low-res (single-crop)
                image; `None` means "follow `image_use_col_tokens`".
            use_single_crop_start_token: start the low-res image with
                `<low_res_im_start>` instead of `<im_start>`.
            video_use_col_tokens: append `<im_col>` after each row of video
                frame patch tokens.
            use_frame_special_tokens: wrap frames with `<frame_start>`/`<frame_end>`
                instead of `<im_start>`/`<im_end>`.
            **kwargs: ignored (accepted for config forward-compatibility).
        """
        super().__init__(
            image_processor,
            video_processor,
            tokenizer,
            chat_template=chat_template,
        )
        self.image_use_col_tokens = image_use_col_tokens
        self.use_single_crop_col_tokens = use_single_crop_col_tokens
        self.use_single_crop_start_token = use_single_crop_start_token
        self.video_use_col_tokens = video_use_col_tokens
        self.use_frame_special_tokens = use_frame_special_tokens

        self.image_placeholder_token = IMAGE_PROMPT
        self.video_placeholder_token = VIDEO_PROMPT
        # Token ids of the image-structure special tokens, used to build
        # `token_type_ids` marking multimodal positions.
        self.image_token_ids = [
            tokenizer.convert_tokens_to_ids(token)
            for token in IMAGE_TOKENS
        ]

    def get_image_tokens(self, image_grid: np.ndarray):
        """Build the flat array of special-token strings standing in for one image.

        `image_grid` holds four ints: the token grid of the low-res resized
        image (`resized_h`, `resized_w`) followed by the token grid of the
        high-res crops (`height`, `width`). `height`/`width` of 0 means there
        are no high-res crops, so only the low-res tokens are emitted.
        """
        resized_h, resized_w, height, width = image_grid
        if int(height) == 0 or int(width) == 0:
            # Low-res only: one <im_patch> per token, optional <im_col> per row.
            per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
            use_single_crop_col_tokens = (
                self.image_use_col_tokens
                if self.use_single_crop_col_tokens is None
                else self.use_single_crop_col_tokens
            )
            if use_single_crop_col_tokens:
                per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
            joint = [
                [IM_START_TOKEN],
                np.tile(per_row, [resized_h]),
                [IM_END_TOKEN],
            ]
            return np.concatenate(joint)
        # High-res crop tokens.
        per_row = np.full(width, IMAGE_PATCH_TOKEN)
        if self.image_use_col_tokens:
            per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
        joint = [
            [IM_START_TOKEN],
            np.tile(per_row, [height]),
            [IM_END_TOKEN],
        ]
        # Low-res tokens, prepended in front of the high-res tokens below.
        per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
        use_single_crop_col_tokens = (
            self.image_use_col_tokens
            if self.use_single_crop_col_tokens is None
            else self.use_single_crop_col_tokens
        )
        image_start_token = (
            LOW_RES_IMAGE_START_TOKEN
            if self.use_single_crop_start_token
            else IM_START_TOKEN
        )
        if use_single_crop_col_tokens:
            per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
        joint = [
            [image_start_token],
            np.tile(per_row, [resized_h]),
            [IM_END_TOKEN],
        ] + joint

        return np.concatenate(joint)

    def get_video_string(
        self,
        video_grid: np.ndarray,
        timestamps: np.ndarray,
    ) -> str:
        """Build the text standing in for one video: for each frame, a
        `per-frame-compact` time prefix ("<t.t> ") followed by that frame's
        grid of patch special tokens wrapped in start/end tokens."""
        if self.use_frame_special_tokens:
            start_token_id = FRAME_START_TOKEN
            end_token_id = FRAME_END_TOKEN
        else:
            start_token_id = IM_START_TOKEN
            end_token_id = IM_END_TOKEN

        num_frames, h, w = video_grid
        video_string: str = ""
        for frame_idx, frame_time in enumerate(timestamps):
            # `per-frame-compact` time mode
            prev_space = " " if frame_idx > 0 else ""
            frame_prefix = prev_space + f"{frame_time:.1f} "  # explicit whitespace before/after image tokens

            video_string += frame_prefix
            per_row = np.full(w, IMAGE_PATCH_TOKEN)
            if self.video_use_col_tokens:
                per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
            extra_tokens = np.tile(per_row, [h])
            video_tokens = [
                [start_token_id],
                extra_tokens,
                [end_token_id],
            ]
            video_string += "".join(np.concatenate(video_tokens, 0))

        return video_string

    def insert_bos(
        self,
        input_ids: np.ndarray,
        attention_mask: np.ndarray,
        bos_token_id: int,
        pad_token_id: int,
    ):
        """
        Insert a BOS token at the first valid (non-pad) position of each row,
        shifting the valid tokens right by one, unless BOS is already there.

        Args:
            input_ids: [B, S] array with left padding (a 1-D [S] array is also
                accepted and returned in the same shape)
            attention_mask: [B, S] array (0 for pad, 1 for valid)
            bos_token_id: int
            pad_token_id: int
        Returns:
            input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
            attention_mask_out: same shape as input_ids_out

        NOTE(review): BOS is treated as already present only if EVERY row
        starts with it; a mixed batch gets BOS inserted into all rows.
        """

        # Accept a single unbatched sequence; restore its shape on return.
        need_to_expand = len(input_ids.shape) == 1
        if need_to_expand:
            input_ids = input_ids[None, :]
            attention_mask = attention_mask[None, :]

        B, S = input_ids.shape

        # Handle zero-length sequence
        if S == 0:
            new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
            new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
            if need_to_expand:
                new_input_ids = new_input_ids[0]
                new_attention_mask = new_attention_mask[0]
            return new_input_ids, new_attention_mask

        # With left padding, argmax finds the first 1 in each row's mask.
        first_valid_index = (attention_mask == 1).argmax(axis=-1)  # [B]
        bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)

        if bos_already_present:
            if need_to_expand:
                input_ids = input_ids[0]
                attention_mask = attention_mask[0]
            return input_ids, attention_mask
        else:
            # Grow the sequence by one and scatter the valid tokens one slot right.
            new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
            new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)

            src_idx = np.tile(np.arange(S), (B, 1))  # [B, S]
            valid_mask = src_idx >= first_valid_index[:, None]  # [B, S]
            tgt_idx = src_idx + 1  # shift right
            batch_idx = np.tile(np.arange(B)[:, None], (1, S))  # [B, S]

            # flatten valid_positions
            flat_vals = input_ids[valid_mask]
            flat_batch = batch_idx[valid_mask]
            flat_tgt = tgt_idx[valid_mask]

            new_input_ids[flat_batch, flat_tgt] = flat_vals
            new_attention_mask[flat_batch, flat_tgt] = 1

            # BOS goes into the slot freed up at each row's original start.
            insert_pos = first_valid_index
            new_input_ids[np.arange(B), insert_pos] = bos_token_id
            new_attention_mask[np.arange(B), insert_pos] = 1

            if need_to_expand:
                new_input_ids = new_input_ids[0]
                new_attention_mask = new_attention_mask[0]

            return new_input_ids, new_attention_mask

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        images: ImageInput = None,
        videos: VideoInput = None,
        **kwargs: Unpack[MolmoAct2ProcessorKwargs],
    ) -> BatchFeature:
        """
        Process text with optional images/videos into model inputs.

        Args:
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            videos (`dict[str, Any]` or `list[dict[str, Any]]`):
                The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
                - `"frames"`: `np.ndarray` of shape (T, H, W, 3)
                - `"timestamps"`: `np.ndarray` of shape (T,)
                - `"sampled_fps"`: `float` (optional)
                - `"sampling_augmentation"`: `str` (optional)
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            `BatchFeature`: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
              Returned when `images` is not `None`.
            - **image_grids** -- Grids of images. Returned when `images` is not `None`.
            - **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
              Returned when `videos` is not `None`.
            - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
        """

        output_kwargs = self._merge_kwargs(
            MolmoAct2ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
            image_grids = image_inputs["image_grids"]
        else:
            image_inputs = {}
            image_grids = None

        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grids = videos_inputs["video_grids"]
            # If user has not requested video metadata, pop it
            if "return_metadata" not in kwargs:
                video_metadata = videos_inputs.pop("video_metadata")
            else:
                video_metadata = videos_inputs["video_metadata"]
        else:
            videos_inputs = {}
            video_grids = None

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place

        # Expand each <|image|> placeholder into its image's special tokens,
        # consuming image grids in order across the batch.
        if image_grids is not None:
            index = 0
            for i in range(len(text)):
                num_images = text[i].count(self.image_placeholder_token)
                image_grids_i = image_grids[index:index+num_images]
                for image_grid in image_grids_i:
                    image_tokens = self.get_image_tokens(image_grid)
                    image_string = "".join(image_tokens)
                    text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
                index += num_images

        # Expand each <|video|> placeholder likewise (at most one per text).
        if video_grids is not None:
            index = 0
            for i in range(len(text)):
                num_videos = text[i].count(self.video_placeholder_token)
                assert num_videos in {0, 1}, "At most one video is supported for now"
                video_grids_i = video_grids[index:index+num_videos]
                metadata_i = video_metadata[index:index+num_videos]
                for video_grid, metadata in zip(video_grids_i, metadata_i):
                    video_string = self.get_video_string(
                        video_grid,
                        metadata.timestamps,
                    )
                    text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
                index += num_videos

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        input_ids = np.array(input_ids)
        attention_mask = np.array(attention_mask)

        # Fall back to EOS as the BOS token when the tokenizer defines none.
        bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
        input_ids, attention_mask = self.insert_bos(
            input_ids, attention_mask, bos, self.tokenizer.pad_token_id
        )

        if return_mm_token_type_ids:
            # 1 where input_ids is any image special token, else 0.
            image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
            token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
            text_inputs["token_type_ids"] = token_type_ids.tolist()

        text_inputs["input_ids"] = input_ids.tolist()
        text_inputs["attention_mask"] = attention_mask.tolist()

        return BatchFeature(
            data={**text_inputs, **image_inputs, **videos_inputs},
            tensor_type=return_tensors,
        )

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode method`.

        Returns:
            `list[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
416
+
417
+
418
+ MolmoAct2Processor.register_for_auto_class()
processor_config.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
4
+ },
5
+ "image_processor": {
6
+ "auto_map": {
7
+ "AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
8
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
9
+ },
10
+ "crop_mode": "resize",
11
+ "do_convert_rgb": true,
12
+ "image_mean": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "image_processor_type": "MolmoAct2ImageProcessor",
18
+ "image_std": [
19
+ 0.5,
20
+ 0.5,
21
+ 0.5
22
+ ],
23
+ "max_crops": 8,
24
+ "overlap_margins": [
25
+ 4,
26
+ 4
27
+ ],
28
+ "patch_size": 14,
29
+ "pooling_size": [
30
+ 2,
31
+ 2
32
+ ],
33
+ "resample": 2,
34
+ "size": {
35
+ "height": 378,
36
+ "width": 378
37
+ }
38
+ },
39
+ "image_use_col_tokens": true,
40
+ "processor_class": "MolmoAct2Processor",
41
+ "use_frame_special_tokens": true,
42
+ "use_single_crop_col_tokens": false,
43
+ "use_single_crop_start_token": true,
44
+ "video_processor": {
45
+ "auto_map": {
46
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
47
+ "AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
48
+ },
49
+ "data_format": "channels_first",
50
+ "default_to_square": true,
51
+ "do_convert_rgb": true,
52
+ "do_normalize": true,
53
+ "do_rescale": true,
54
+ "do_resize": true,
55
+ "do_sample_frames": true,
56
+ "frame_sample_mode": "uniform_last_frame",
57
+ "image_mean": [
58
+ 0.5,
59
+ 0.5,
60
+ 0.5
61
+ ],
62
+ "image_std": [
63
+ 0.5,
64
+ 0.5,
65
+ 0.5
66
+ ],
67
+ "max_fps": 2.0,
68
+ "num_frames": 8,
69
+ "patch_size": 14,
70
+ "pooling_size": [
71
+ 3,
72
+ 3
73
+ ],
74
+ "resample": 2,
75
+ "rescale_factor": 0.00392156862745098,
76
+ "return_metadata": false,
77
+ "sampling_fps": 2,
78
+ "size": {
79
+ "height": 378,
80
+ "width": 378
81
+ },
82
+ "video_processor_type": "MolmoAct2VideoProcessor"
83
+ },
84
+ "video_use_col_tokens": false
85
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b6aeec78de2b0c7e95d7ae9d71cd04eba3d57351045a86c95520730e9c80d83
3
+ size 12176547
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
5
+ },
6
+ "backend": "tokenizers",
7
+ "bos_token": "<|im_end|>",
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "extra_special_tokens": [
12
+ "<im_start>",
13
+ "<im_end>",
14
+ "<im_patch>",
15
+ "<im_col>",
16
+ "<low_res_im_start>",
17
+ "<|image|>",
18
+ "<im_low>",
19
+ "<frame_start>",
20
+ "<frame_end>",
21
+ "<|video|>",
22
+ "<|points|>",
23
+ "<|token_index|>",
24
+ "<|vit_index|>",
25
+ "<|vit_loc|>"
26
+ ],
27
+ "is_local": false,
28
+ "model_max_length": 1010000,
29
+ "pad_token": "<|endoftext|>",
30
+ "processor_class": "MolmoAct2Processor",
31
+ "split_special_tokens": false,
32
+ "tokenizer_class": "Qwen2Tokenizer",
33
+ "unk_token": null
34
+ }
video_processing_molmoact2.py ADDED
@@ -0,0 +1,969 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Video processor class for MolmoAct2"""
2
+ from functools import partial
3
+ import os
4
+ import warnings
5
+ from contextlib import redirect_stdout
6
+ from io import BytesIO
7
+ from urllib.parse import urlparse
8
+ from typing import Optional, Union, Callable
9
+
10
+ import numpy as np
11
+ import requests
12
+ import einops
13
+ import torch
14
+ import torchvision.transforms
15
+
16
+ from transformers.image_utils import (
17
+ IMAGENET_STANDARD_MEAN,
18
+ IMAGENET_STANDARD_STD,
19
+ ImageInput,
20
+ PILImageResampling,
21
+ SizeDict,
22
+ validate_kwargs,
23
+ )
24
+ from transformers.video_utils import (
25
+ VideoInput,
26
+ is_valid_video,
27
+ make_batched_videos,
28
+ make_batched_metadata,
29
+ VideoMetadata,
30
+ )
31
+ from transformers.processing_utils import Unpack, VideosKwargs
32
+ from transformers.video_processing_utils import BaseVideoProcessor
33
+ from transformers.utils import logging
34
+ from transformers.feature_extraction_utils import BatchFeature
35
+ from transformers.utils import (
36
+ is_av_available,
37
+ is_decord_available,
38
+ is_torchcodec_available,
39
+ is_yt_dlp_available,
40
+ TensorType,
41
+ logging,
42
+ to_numpy,
43
+ )
44
+
45
+
46
logger = logging.get_logger(__name__)

# Upper bound (frames per second) on the candidate sampling rates considered
# by `get_candidate_target_fps`.
MAX_VIDEO_FPS = 8
49
+
50
+
51
def normalize_image(
    image: np.ndarray,
    image_mean: list[float],
    image_std: list[float],
) -> np.ndarray:
    """Channel-wise normalize a [..., H, W, 3] image: (image - mean) / std.

    Fix over the previous version: the general path no longer mutates the
    caller's array in place (`-=` / `/=`), which also raised an unsafe-cast
    error for integer inputs.

    Args:
        image: channels-last image (or stack of frames), typically float32 in [0, 1].
        image_mean: per-channel mean.
        image_std: per-channel std.
    Returns:
        The normalized image as a new array (input is left untouched).
    """
    if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
        # Fast path: (x - 0.5) / 0.5 == 2x - 1.
        return image * np.float32(2.0) - np.float32(1.0)
    mean = np.array(image_mean, dtype=np.float32)[None, None, :]
    std = np.array(image_std, dtype=np.float32)[None, None, :]
    # Out-of-place arithmetic: never clobber the caller's buffer.
    return (image - mean) / std
61
+
62
+
63
def resize_image(
    image: np.ndarray,
    desired_output_size: list[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an image [H, W, C] or video [T, H, W, C] to `desired_output_size`.

    Accepts float (range [0, 1]) or uint8 (range [0, 255]) inputs and always
    returns float32 scaled to [0, 1], in channels-last layout.
    """
    is_video = len(image.shape) != 3
    # Channels first for torchvision: HWC -> CHW, or THWC -> TCHW.
    tensor = torch.permute(
        torch.from_numpy(image),
        [0, 3, 1, 2] if is_video else [2, 0, 1],
    )
    original_dtype = tensor.dtype
    resizer = torchvision.transforms.Resize(
        desired_output_size,
        resample,
        antialias=False,
    )
    if torch.is_floating_point(tensor):
        in_min, in_max = 0.0, 1.0
        resized = torch.clip(resizer(tensor), 0.0, 1.0).to(original_dtype)
    else:
        assert tensor.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(tensor.dtype)
        in_min, in_max = 0.0, 255.0
        resized = torch.clip(resizer(tensor), 0, 255).to(original_dtype)

    # Rescale to [0, 1] float32.
    resized = (resized.to(torch.float32) - in_min) / (in_max - in_min)

    # Back to channels-last numpy.
    resized = torch.permute(
        resized,
        [0, 2, 3, 1] if is_video else [1, 2, 0],
    )
    return resized.numpy()
104
+
105
+
106
def build_resized_image(
    image: np.ndarray,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Resize + normalize an image to the base input size and build its patch index grid.

    Returns:
        A pair of (normalized image with a leading crop axis [1, H, W, C],
        [h_patches, w_patches] array enumerating the patch indices row-major).
    """
    prepared = normalize_image(
        resize_image(image, base_image_input_size, resample),
        image_mean,
        image_std,
    )
    if prepared.ndim == 3:
        # Add the singleton "crops" axis expected downstream.
        prepared = prepared[None]
    patches_h = base_image_input_size[0] // image_patch_size
    patches_w = base_image_input_size[1] // image_patch_size
    index_grid = np.arange(patches_h * patches_w).reshape([patches_h, patches_w])
    return prepared, index_grid
124
+
125
+
126
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch].

    Also accepts a channel-less [n_images, h, w] array. Patches are ordered
    row-major; pixels within a patch are ordered row-major as well.
    """
    if len(array.shape) == 3:
        n_images, height, width = array.shape
        rows = height // patch_size
        cols = width // patch_size
        patched = np.reshape(array, [n_images, rows, patch_size, cols, patch_size])
        # Bring the two patch-grid axes together, then the in-patch axes.
        patched = np.transpose(patched, [0, 1, 3, 2, 4])
        return np.reshape(patched, [n_images, rows * cols, patch_size * patch_size])

    n_images, height, width, channels = array.shape
    rows = height // patch_size
    cols = width // patch_size
    patched = np.reshape(array, [n_images, rows, patch_size, cols, patch_size, channels])
    patched = np.transpose(patched, [0, 1, 3, 2, 4, 5])
    return np.reshape(patched, [n_images, rows * cols, patch_size * patch_size * channels])
144
+
145
+
146
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2D patch-index grid into (pool_h, pool_w) pooling windows.

    The grid is padded with -1 (split as evenly as possible between the two
    sides of each axis) so its dimensions become multiples of the pooling
    window; -1 marks padded positions a pooled token should ignore.

    Improvement: the trailing `einops.rearrange` call is replaced with the
    equivalent pure-numpy reshape/transpose, dropping the einops dependency
    for this helper.

    Returns:
        Array of shape [ceil(h/pool_h), ceil(w/pool_w), pool_h*pool_w] giving,
        for each pooled cell, the patch indices it covers (-1 for padding).
    """
    h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
    w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
    idx_arr = np.pad(
        idx_arr,
        [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
        mode='constant', constant_values=-1,
    )
    # Equivalent to einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)",
    # dh=pool_h, dw=pool_w).
    out_h = idx_arr.shape[0] // pool_h
    out_w = idx_arr.shape[1] // pool_w
    pooled = idx_arr.reshape(out_h, pool_h, out_w, pool_w)
    pooled = pooled.transpose(0, 2, 1, 3)
    return pooled.reshape(out_h, out_w, pool_h * pool_w)
157
+
158
+
159
def image_to_patches_and_grids(
    image: ImageInput,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    :return image_grids, the shape of each image after pooling
    :return crops, the image crops to processes with the ViT
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
        patches in `crops` to pool for that token, masked with -1
    """
    if isinstance(base_image_input_size, int):
        # A single int means a square input size.
        base_image_input_size = (base_image_input_size, base_image_input_size)

    crops, patch_index_grid = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    pooled_idx = arange_for_pooling(patch_index_grid, image_pooling_h, image_pooling_w)
    grid_h, grid_w = pooled_idx.shape[:2]
    # Flatten the pooled grid into [n_tokens, window_size] for gathering.
    pooled_idx = pooled_idx.reshape([-1, image_pooling_h * image_pooling_w])
    return (
        [grid_h, grid_w],
        batch_pixels_to_patches(crops, image_patch_size),
        pooled_idx,
    )
198
+
199
+
200
def get_candidate_target_fps(
    video_fps: Union[int, float],
    sampling_fps: Union[int, float],
    max_fps: Optional[Union[int, float]] = None,
) -> list[float]:
    """
    Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.

    `max_fps` defaults to the module-level `MAX_VIDEO_FPS` cap.

    Fixes over the previous version: the `sampling_fps is None` guard now runs
    BEFORE `int(sampling_fps)` (previously it was dead code — `int(None)`
    raised a TypeError first), and the doctest examples match the actual
    float return values and error message.

    Examples:
        >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
        [2.0, 6.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
        [1.0, 5.0]
        >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
        [2.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
        Traceback (most recent call last):
        ...
        ValueError: sampling_fps=2 must divide video_fps=5.
    """
    # Validate before converting: int(None) would raise a confusing TypeError.
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")
    if max_fps is None:
        max_fps = MAX_VIDEO_FPS

    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)

    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
    if video_fps % sampling_fps != 0:
        raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")

    # Multiples of sampling_fps that also divide video_fps evenly, capped at max_fps.
    candidates = []
    for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
        if candidate > max_fps:
            break
        if video_fps % candidate == 0:
            candidates.append(float(candidate))

    return candidates
239
+
240
+
241
def read_video_decord(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> tuple[np.ndarray, VideoMetadata]:
    """
    Decode a video using the Decord backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.
            It is called with `metadata=` plus any extra `**kwargs`.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import from decord
    import importlib
    decord = importlib.import_module("decord")

    vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0))  # decord has problems with gpu
    video_fps = vr.get_avg_fps()
    total_num_frames = len(vr)
    # Per-frame [start, end] presentation times, shape [num_frames, 2].
    time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
    duration = time_stamps[-1][1] - time_stamps[0][0]

    metadata = VideoMetadata(
        total_num_frames=int(total_num_frames),
        fps=float(video_fps),
        duration=float(duration),
        video_backend="decord",
    )

    target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
    target_timestamps = np.array(target_timestamps)
    # Requested timestamps are relative to the first frame; shift them into
    # the stream's absolute time base.
    offset = time_stamps[0, 0]

    # For each target time, pick the first frame whose end time exceeds it,
    # i.e. the frame being displayed at that moment; clamp to the last frame.
    ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
    ix = np.minimum(ix, len(time_stamps) - 1)

    video = vr.get_batch(ix).asnumpy()
    # NOTE(review): assumes the installed transformers' VideoMetadata supports
    # dict-style `update`; `frames_indices` are float approximations
    # (timestamp * avg_fps), not the exact integer indices in `ix` — confirm
    # downstream consumers expect this.
    metadata.update(
        {
            "frames_indices": target_timestamps * video_fps,
            "height": video.shape[1],
            "width": video.shape[2],
        }
    )
    return video, metadata
293
+
294
+
295
def read_video_torchcodec(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> np.ndarray:
    """
    Decode a video using torchcodec decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import so torchcodec is only required when this backend is actually used
    import importlib
    torchcodec = importlib.import_module("torchcodec")

    decoder = torchcodec.decoders.VideoDecoder(
        video_path,
        # Interestingly `exact` mode takes less than approximate when we load the whole video
        seek_mode="exact",
        # Allow FFmpeg decide on the number of threads for efficiency
        num_ffmpeg_threads=0,
    )
    # If the first frame starts at > 0, we effectively clip the video starting at that time
    # since (most) video players would also skip to that time
    time_offset = decoder.metadata.begin_stream_seconds_from_content
    # Note this duration does assume we started playing at `time_offset`
    duration = decoder.metadata.duration_seconds

    metadata = VideoMetadata(
        total_num_frames=decoder.metadata.num_frames,
        fps=decoder.metadata.average_fps,
        duration=duration,
        video_backend="torchcodec",
        height=decoder.metadata.height,
        width=decoder.metadata.width,
    )

    # The sampler sees the metadata and decides which (relative) times to decode
    target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)

    # Floating point/rounding issues might cause `target_timestamps` to be very slightly
    # out-of-bounds, to handle this we sanity check then clip them
    assert all(x >= 0 for x in target_timestamps)
    assert all(x < duration+1e-6 for x in target_timestamps)
    # 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
    # exact boundary value, we should still get the first/last frame anyway
    max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
    min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
    # Note we avoid using numpy ops here to reduce floating precision issues
    timestamps = [x + time_offset for x in target_timestamps]
    timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]

    video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format
    target_timestamps = np.array(target_timestamps)
    # Fractional frame indices corresponding to the requested times
    metadata.frames_indices = target_timestamps * metadata.fps

    return video, metadata
359
+
360
+
361
def read_video_pyav(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> np.ndarray:
    """
    Decode a video using the PyAV backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import so PyAV (`av`) is only required when this backend is actually used
    import importlib
    av = importlib.import_module("av")

    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        # Prefer the container-reported rate; fall back to the guessed one
        fps = video_stream.average_rate or video_stream.guessed_rate
        it = container.decode(video=0)
        # Decodes the entire stream up front; memory grows with video length
        frames = list(it)

        stream = container.streams.video[0]
        # pts * time_base converts raw ticks to seconds
        start = frames[0].pts * stream.time_base
        container_end = stream.duration
        if container_end is not None:
            container_end *= stream.time_base
        # NOTE(review): `container_end` is in seconds here but `frames[-1].pts` is in raw
        # time_base ticks — this comparison looks unit-mismatched (likely intended
        # `frames[-1].pts * stream.time_base`); confirm against PyAV semantics.
        if container_end is None or container_end < frames[-1].pts:
            # Some problem with stream duration, so use the frame PTS directly
            # and guess the duration of the last frame
            end = frames[-1].pts * stream.time_base + 1/fps
        else:
            end = container_end
        duration = float(end - start)

        metadata = VideoMetadata(
            total_num_frames=len(frames),
            fps=float(fps),
            duration=float(duration),
            video_backend="pyav",
            height=video_stream.height,
            width=video_stream.width,
        )

        # The sampler sees the metadata and decides which (relative) times to decode
        target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
        offset = float(start)

        target_timestamps = np.array(target_timestamps)
        # End time of each frame = start time of the next frame; the last frame ends at `duration`.
        # NOTE(review): entries from frame PTS are absolute stream times while `duration` is
        # relative to `start` — if the stream starts at > 0 these mix frames of reference; verify.
        end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
        # Pick the frame whose end time is the first one strictly after the target, clamped
        indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
        indices = np.minimum(indices, len(end_time_stamps) - 1)

        video = np.stack(
            [frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
            axis=0,
        )

        # Fractional frame indices corresponding to the requested times
        metadata.frames_indices = target_timestamps * fps

        return video, metadata
428
+
429
+
430
# Maps backend name -> decoder function; `load_video` dispatches through this table.
VIDEO_DECODERS = {
    "decord": read_video_decord,
    "torchcodec": read_video_torchcodec,
    "pyav": read_video_pyav,
}
435
+
436
+
437
def load_video(
    video: VideoInput,
    backend: str = "decord",
    sample_timestamps_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Loads `video` to a numpy array.

    Args:
        video (`VideoInput`):
            The video to convert to the numpy array format. Can be a link to video or local path.
        backend (`str`, *optional*, defaults to `"decord"`):
            The backend to use when loading the video. Can be any of ["decord", "pyav", "torchcodec"]. Defaults to "decord".
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.

    Returns:
        tuple: `(video, metadata)` as produced by the selected backend decoder, or the
        input unchanged (with `None` metadata per frame) when `video` is already decoded.

    Raises:
        ImportError: if the required downloader/decoder library is not installed.
        TypeError: if `video` is a string that is neither a URL nor an existing file.
        ValueError: if `backend` is unsupported or incompatible with a URL input.
    """

    # Early exit if provided an array or `PIL` frames
    if not isinstance(video, str):
        metadata = [None] * len(video)
        return video, metadata

    if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
        if not is_yt_dlp_available():
            raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
        # Lazy import from yt_dlp
        import importlib
        yt_dlp = importlib.import_module("yt_dlp")

        # yt_dlp prints progress to stdout; capture it so it does not pollute caller output
        buffer = BytesIO()
        with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
            f.download([video])
        bytes_obj = buffer.getvalue()
        file_obj = BytesIO(bytes_obj)
    elif video.startswith("http://") or video.startswith("https://"):
        file_obj = BytesIO(requests.get(video).content)
    elif os.path.isfile(video):
        file_obj = video
    else:
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    # can also load with decord, but not cv2/torchvision
    # both will fail in case of url links
    video_is_url = video.startswith("http://") or video.startswith("https://")
    if video_is_url and backend == "opencv":
        raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")

    # Fail with a clear error instead of a bare KeyError at dispatch time below
    if backend not in VIDEO_DECODERS:
        raise ValueError(
            f"Unsupported video backend: {backend!r}. Supported backends: {sorted(VIDEO_DECODERS)}"
        )

    if (
        (not is_decord_available() and backend == "decord")
        or (not is_torchcodec_available() and backend == "torchcodec")
        or (not is_av_available() and backend == "pyav")
    ):
        raise ImportError(
            f"You chose backend={backend} for loading the video but the required library is not found in your environment "
            f"Make sure to install {backend} before loading the video."
        )

    video_decoder = VIDEO_DECODERS[backend]
    video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
    return video, metadata
498
+
499
+
500
def get_target_fps(
    video_fps: float,
    max_frames: int,
    total_frames: int,
    frame_sample_mode: str,
    candidate_target_fps: tuple[float],
) -> float:
    """
    Pick the candidate fps that best spans the video while sampling the most frames.

    Candidates must be in increasing order; returns `None` when no candidate is viable.
    """
    best_fps = None
    best_count = 0
    for fps in candidate_target_fps:
        stride = max(int(video_fps / fps), 1)
        count = int(total_frames / stride)
        if best_count == 0:
            # No viable candidate yet. In uniform modes, refuse any fps that
            # would already oversample beyond max_frames.
            if "uniform" in frame_sample_mode and count > max_frames:
                break
            best_fps, best_count = fps, count
            continue
        # Candidate fps increases across the loop, so the frame count cannot decrease
        assert best_count <= count
        # Skip candidates that overshoot max_frames (they would not span the video);
        # among the rest, keep the one with the densest sampling.
        if count <= max_frames and count > best_count:
            best_fps, best_count = fps, count
    return best_fps
534
+
535
+
536
def get_frame_times_and_chosen_fps(
    selected_target_fps,
    total_frames,
    max_frames,
    video_fps
):
    """Return `(selected_target_fps, frame_indices)` for the chosen sampling rate.

    With no viable fps (`None`), fall back to `max_frames` uniformly spaced indices;
    otherwise take every `stride`-th frame, truncated to at most `max_frames`.
    """
    if selected_target_fps is None:
        # Uniform fallback over the whole clip
        indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
        return selected_target_fps, indices
    stride = max(int(video_fps / selected_target_fps), 1)
    indices = np.arange(0, total_frames, stride)
    # Slicing is a no-op when there are already <= max_frames indices
    return selected_target_fps, indices[:max_frames]
550
+
551
+
552
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
    # Extra kwargs accepted by MolmoAct2VideoProcessor on top of the base `VideosKwargs`:
    #   patch_size: spatial patch size of the vision encoder.
    #   pooling_size: [h, w] pooling factors of the vision adapter.
    #   frame_sample_mode: frame sampling strategy ("fps" or "uniform_last_frame").
    #   max_fps: upper bound on sampled frames per second.
    #   sampling_fps: base sampling rate, used when frame_sample_mode == "fps".
    patch_size: Optional[int]
    pooling_size: Optional[list[int]]
    frame_sample_mode: Optional[str]
    max_fps: Optional[int]
    sampling_fps: Optional[int]
558
+
559
+
560
class MolmoAct2VideoProcessor(BaseVideoProcessor):
    """Video processor for MolmoAct2.

    Decodes/samples video frames and converts them into model inputs:
    `pixel_values_videos`, `video_token_pooling` and `video_grids`.
    """

    resample = PILImageResampling.BILINEAR
    size = {"height": 378, "width": 378}
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    patch_size = 14
    pooling_size = [3, 3]
    do_sample_frames = True
    frame_sample_mode = "uniform_last_frame"
    max_fps = 2
    sampling_fps = 2
    valid_kwargs = MolmoAct2VideoProcessorKwargs
    model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]

    def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
        """Initialize the processor, validating that `size` has both dimensions."""
        super().__init__(**kwargs)
        if self.size is not None and (
            self.size.get("height", None) is None or self.size.get("width", None) is None
        ):
            raise ValueError("size must contain 'height' and 'width' keys.")

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if size is not None and ("height" not in size or "width" not in size):
            raise ValueError("size must contain 'height' and 'width' keys.")

        return super()._further_process_kwargs(size=size, **kwargs)

    def sample_times(
        self,
        metadata: VideoMetadata,
        frame_sample_mode: str,
        num_frames: int,
        max_fps: Optional[int] = None,
        sampling_fps: Optional[int] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Time-based sampling if an array video is passed
        Args:
            metadata (`VideoMetadata`):
                Metadata of the video containing information about total duration, fps and total number of frames.
            frame_sample_mode (`str`, *optional*):
                Mode to sample frames. Defaults to `self.frame_sample_mode`.
            num_frames (`int`, *optional*):
                Maximum number of frames to sample. Defaults to `self.num_frames`.
            max_fps (`int`, *optional*):
                Maximum frames per second to sample.
            sampling_fps (`int`, *optional*):
                Sampling frames per second. Defaults to `self.sampling_fps`.
                Used when `frame_sample_mode` is `"fps"`.
        """
        frame_sample_mode = frame_sample_mode or self.frame_sample_mode
        num_frames = num_frames or self.num_frames
        sampling_fps = sampling_fps or self.sampling_fps

        duration = metadata.duration or metadata.total_num_frames / metadata.fps
        if frame_sample_mode == "fps":
            candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
            # Try larger and larger FPSs until we hit one that can't span the video
            target_fps = candidate_target_fps[0]
            for candidate_fps in candidate_target_fps[1:]:
                if num_frames / candidate_fps < duration:
                    break
                target_fps = candidate_fps
            times = np.arange(0, num_frames) / target_fps
            times = times[times < duration]
            return times
        elif frame_sample_mode == "uniform_last_frame":
            if max_fps is not None:
                max_duration = (num_frames-1) / max_fps  # -1 to include the last frame
                if max_duration < duration:
                    # Video too long to honor max_fps: spread frames uniformly instead
                    times = np.linspace(
                        0, duration, num=num_frames, endpoint=True, dtype=np.float64
                    )
                else:
                    times = np.arange(0.0, stop=duration, step=1/max_fps)
                    # Always include the final timestamp so the last frame is sampled
                    times = np.concatenate([times, [duration]], axis=0)
                    assert len(times) <= num_frames
            else:
                times = np.linspace(
                    0, duration, num=num_frames, endpoint=True, dtype=np.float64
                )
            return times
        else:
            raise NotImplementedError(frame_sample_mode)

    def sample_frames(
        self,
        metadata: VideoMetadata,
        frame_sample_mode: Optional[str] = None,
        num_frames: Optional[int] = None,
        max_fps: Optional[int] = None,
        sampling_fps: Optional[int] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Frame-based sampling if an array video is passed
        Args:
            metadata (`VideoMetadata`):
                Metadata of the video containing information about total duration, fps and total number of frames.
            frame_sample_mode (`str`, *optional*):
                Mode to sample frames. Defaults to `self.frame_sample_mode`.
            num_frames (`int`, *optional*):
                Maximum number of frames to sample. Defaults to `self.num_frames`.
            max_fps (`int`, *optional*):
                Maximum frames per second to sample.
            sampling_fps (`int`, *optional*):
                Sampling frames per second. Defaults to `self.sampling_fps`.
                Used when `frame_sample_mode` is `"fps"`.
        """
        frame_sample_mode = frame_sample_mode or self.frame_sample_mode
        num_frames = num_frames or self.num_frames
        sampling_fps = sampling_fps or self.sampling_fps

        total_num_frames = metadata.total_num_frames
        if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
            duration = total_num_frames / metadata.fps
            if total_num_frames <= 2:
                return np.arange(total_num_frames).astype(int)
            if duration > (num_frames - 1) / max_fps:  # -1 to include the last frame
                # uniform fallback
                indices = np.linspace(
                    0,
                    total_num_frames - 1,
                    num=min(num_frames, total_num_frames),
                    endpoint=True,
                ).astype(int)
                return indices
            else:
                float_indices = np.arange(
                    0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
                )
                # Ensure the last frame is always included
                if np.round(float_indices[-1]) != total_num_frames - 1:
                    float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
                indices = np.round(float_indices).astype(int)
                assert indices[-1] < total_num_frames
                assert len(float_indices) <= num_frames
                return indices
        elif frame_sample_mode == "uniform_last_frame":
            indices = np.linspace(
                0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
            ).astype(int)
            return indices
        elif frame_sample_mode == "fps":
            candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
            selected_target_fps = get_target_fps(
                metadata.fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidate_target_fps,
            )
            _, indices = get_frame_times_and_chosen_fps(
                selected_target_fps,
                total_num_frames,
                num_frames,
                metadata.fps,
            )
            return indices
        else:
            raise NotImplementedError(frame_sample_mode)

    def fetch_videos(
        self,
        video_url_or_urls: Union[str, list[str], list[list[str]]],
        sample_timestamps_fn=None
    ):
        """
        Convert a single or a list of urls into the corresponding `np.array` objects.

        If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
        returned.
        """
        if (
            (not is_decord_available())
            and (not is_torchcodec_available())
            and (not is_av_available())
        ):
            raise ImportError(
                "MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
            )

        # Backend preference order: decord > torchcodec > pyav
        if is_decord_available():
            backend = "decord"
        elif is_torchcodec_available():
            warnings.warn(
                "`decord` is not installed and cannot be used to decode the video by default. "
                "Falling back to `torchcodec`."
            )
            backend = "torchcodec"
        else:
            warnings.warn(
                "`decord` is not installed and cannot be used to decode the video by default. "
                "Falling back to `PyAV`."
            )
            backend = "pyav"

        if isinstance(video_url_or_urls, list):
            return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
        else:
            return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)

    def _decode_and_sample_videos(
        self,
        videos: VideoInput,
        video_metadata: Union[VideoMetadata, dict],
        do_sample_frames: Optional[bool] = None,
        sample_indices_fn: Optional[Callable] = None,
        sample_timestamps_fn: Optional[Callable] = None,
    ):
        """
        Decode input videos and sample frames if needed.
        """
        videos = make_batched_videos(videos)
        video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)

        # Framed-based sampling if an array video is passed
        # Otherwise, time-based sampling with decoding
        if is_valid_video(videos[0]) and do_sample_frames:
            assert video_metadata[0].fps is not None, "FPS must be provided for video input"
            sampled_videos = []
            sampled_metadata = []
            for video, metadata in zip(videos, video_metadata):
                indices = sample_indices_fn(metadata=metadata)
                metadata.frames_indices = indices
                sampled_videos.append(video[indices])
                sampled_metadata.append(metadata)
            videos = sampled_videos
            video_metadata = sampled_metadata
        elif not is_valid_video(videos[0]):
            if sample_indices_fn is None:
                logger.warning(
                    "do_sample_frames is False, but video array is not provided: "
                    "Will decode the video and sample frames using MolmoAct2's default sampling mode"
                )
            if isinstance(videos[0], list):
                raise ValueError(
                    "A list of images is not supported for video input!"
                )
            else:
                videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)

        return videos, video_metadata

    def _prepare_input_videos(
        self,
        videos: VideoInput,
        **kwargs,
    ) -> list[np.ndarray]:
        """Convert each input video to a numpy array."""
        processed_videos = [to_numpy(video) for video in videos]
        return processed_videos

    def preprocess(
        self,
        videos: VideoInput,
        **kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
    ) -> BatchFeature:
        """Validate kwargs, decode/sample the videos and produce model inputs.

        Returns a `BatchFeature` from `_preprocess`, with `video_metadata`
        attached when `return_metadata` is set.
        """
        validate_kwargs(
            captured_kwargs=kwargs.keys(),
            valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
        )

        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        do_sample_frames = kwargs.pop("do_sample_frames")
        video_metadata = kwargs.pop("video_metadata")

        sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
        sample_timestamps_fn = partial(self.sample_times, **kwargs)
        videos, video_metadata = self._decode_and_sample_videos(
            videos,
            video_metadata=video_metadata,
            do_sample_frames=do_sample_frames,
            sample_indices_fn=sample_indices_fn,
            sample_timestamps_fn=sample_timestamps_fn,
        )
        videos = self._prepare_input_videos(videos=videos)

        kwargs = self._further_process_kwargs(**kwargs)

        return_metadata = kwargs.pop("return_metadata")
        preprocessed_videos = self._preprocess(videos=videos, **kwargs)
        if return_metadata:
            preprocessed_videos["video_metadata"] = video_metadata
        return preprocessed_videos

    def _preprocess(
        self,
        videos: list[np.ndarray],
        size: Optional[SizeDict] = None,
        resample: Optional[PILImageResampling] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        patch_size: Optional[int] = None,
        pooling_size: Optional[list[int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a video for the model.
        Args:
            videos (`VideoInput`):
                Video to preprocess.
            size (`SizeDict`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
                The pooling size of the vision adapter.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.

        Returns:
            A `BatchFeature` containing the following keys:
            - `pixel_values_videos`: The preprocessed videos.
            - `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
            - `video_grids`: The video grids.
        """
        if size.height is None or size.width is None:
            raise ValueError("size must contain 'height' and 'width' keys.")

        base_image_input_size = [size.height, size.width]

        resample = resample or self.resample
        image_mean = image_mean or self.image_mean
        image_std = image_std or self.image_std
        do_convert_rgb = do_convert_rgb or self.do_convert_rgb

        patch_size = patch_size or self.patch_size
        pooling_size = pooling_size or self.pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        batch_grids = []
        batch_crops = []
        batch_pooled_patches_idx = []

        for video in videos:
            all_crops = []
            pooled_patches_idx = []

            for frame in video:
                image_grid, crops, pooled_idx = image_to_patches_and_grids(
                    frame,
                    base_image_input_size,
                    resample,
                    image_mean,
                    image_std,
                    patch_size,
                    image_pooling_w,
                    image_pooling_h,
                )
                # Shift this frame's pooled indices past all patches emitted so far;
                # negative indices are padding markers and must stay negative
                offset = sum(np.prod(x.shape[:2]) for x in all_crops)
                pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
                pooled_patches_idx.append(pooled_idx_with_offset)
                all_crops.append(crops)

            video_grid = np.array([len(video), image_grid[0], image_grid[1]])
            all_crops = np.concatenate(all_crops, 0)
            pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)

            batch_grids.append(video_grid)
            batch_crops.append(all_crops)
            batch_pooled_patches_idx.append(pooled_patches_idx)

        video_grids = np.stack(batch_grids, 0)
        pixel_values_videos = np.concatenate(batch_crops, 0)
        video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)

        data = dict(
            pixel_values_videos=pixel_values_videos,
            video_token_pooling=video_token_pooling,
            video_grids=video_grids,
        )

        return BatchFeature(data, tensor_type=return_tensors)
967
+
968
+
969
+ MolmoAct2VideoProcessor.register_for_auto_class()