hqfang committed on
Commit 09ddc2a · verified · 1 Parent(s): 57bfd8f

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,159 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - molmoact2
5
+ - robotics
6
+ - image-text-to-text
7
+ - libero
8
+ ---
9
+
10
+ <img src="assets/MolmoAct2.svg" alt="MolmoAct Logo" style="width: auto; height: 50px;">
11
+
12
+ # **MolmoAct2-LIBERO**
13
+
14
+ MolmoAct2 is an open vision-language-action model for robot control. It builds on Molmo2-ER and attaches a flow-matching continuous action expert that conditions on the VLM key-value cache through a per-layer connection.
15
+
16
+ This checkpoint is fine-tuned on the full LIBERO training mixture, combining Spatial, Object, Goal, and Long suites. It is intended for both further fine-tuning and LIBERO policy inference.
17
+
18
+ ## Quick Links
19
+
20
+ - 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
21
+ - 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
22
+ - 📄 Paper:
23
+ - 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
24
+ - 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)
25
+
26
+ ## Intended Use
27
+
28
+ Use this checkpoint for LIBERO inference or for further fine-tuning. Dataset normalization metadata is stored in `norm_stats.json`; pass `norm_tag="libero"` at inference time.
29
+
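+ To check which normalization tags the checkpoint ships with, `norm_stats.json` can be inspected directly. The snippet below is a small sketch; it only assumes that the file's top-level keys are the available tags, which is not otherwise documented here.
+
+ ```python
+ import json
+
+ from huggingface_hub import hf_hub_download
+
+ # Download the per-dataset normalization statistics bundled with this repo.
+ stats_path = hf_hub_download("allenai/MolmoAct2-LIBERO", "norm_stats.json")
+ with open(stats_path) as f:
+     norm_stats = json.load(f)
+
+ # Expected to include "libero" for this checkpoint.
+ print(list(norm_stats.keys()))
+ ```
+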
30
+ Continuous action prediction is the intended and recommended inference mode. Discrete action prediction is exposed for parity and debugging, but we use continuous actions by default.
31
+
32
+ ## Install
33
+
34
+ ```bash
35
+ pip install torch transformers pillow numpy huggingface_hub
36
+ ```
37
+
38
+ ## Sample Input
39
+
40
+ This sample comes from `libero_10`, episode 0, frame 0. The LIBERO camera order is front/agent view followed by wrist view.
41
+
42
+ | Agentview RGB | Wrist RGB |
43
+ | --- | --- |
44
+ | ![Sample agentview RGB](assets/sample_agentview_rgb.png) | ![Sample wrist RGB](assets/sample_wrist_rgb.png) |
45
+
46
+ ```python
47
+ from huggingface_hub import hf_hub_download
48
+ from PIL import Image
49
+ import numpy as np
50
+
51
+ repo_id = "allenai/MolmoAct2-LIBERO"
52
+
53
+ agentview_rgb = Image.open(
54
+ hf_hub_download(repo_id, "assets/sample_agentview_rgb.png")
55
+ ).convert("RGB")
56
+ wrist_rgb = Image.open(
57
+ hf_hub_download(repo_id, "assets/sample_wrist_rgb.png")
58
+ ).convert("RGB")
59
+
60
+ task = "put the white mug on the left plate and put the yellow and white mug on the right plate"
61
+ robot_state = np.array(
62
+ [
63
+ -0.05338004603981972,
64
+ 0.007029631175100803,
65
+ 0.6783280968666077,
66
+ 3.1407692432403564,
67
+ 0.0017593271331861615,
68
+ -0.08994418382644653,
69
+ 0.03878866136074066,
70
+ -0.03878721222281456,
71
+ ],
72
+ dtype=np.float32,
73
+ )
74
+ ```
75
+
76
+ ## Continuous Actions
77
+
78
+ ```python
79
+ import numpy as np
80
+ import torch
81
+ from huggingface_hub import hf_hub_download
82
+ from PIL import Image
83
+ from transformers import AutoModelForImageTextToText, AutoProcessor
84
+
85
+ repo_id = "allenai/MolmoAct2-LIBERO"
86
+
87
+ agentview_rgb = Image.open(
88
+ hf_hub_download(repo_id, "assets/sample_agentview_rgb.png")
89
+ ).convert("RGB")
90
+ wrist_rgb = Image.open(
91
+ hf_hub_download(repo_id, "assets/sample_wrist_rgb.png")
92
+ ).convert("RGB")
93
+ task = "put the white mug on the left plate and put the yellow and white mug on the right plate"
94
+ robot_state = np.array(
95
+ [
96
+ -0.05338004603981972,
97
+ 0.007029631175100803,
98
+ 0.6783280968666077,
99
+ 3.1407692432403564,
100
+ 0.0017593271331861615,
101
+ -0.08994418382644653,
102
+ 0.03878866136074066,
103
+ -0.03878721222281456,
104
+ ],
105
+ dtype=np.float32,
106
+ )
107
+
108
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
109
+ model = AutoModelForImageTextToText.from_pretrained(
110
+ repo_id,
111
+ trust_remote_code=True,
112
+ torch_dtype=torch.float32,
113
+ ).to("cuda").eval()
114
+
115
+ out = model.predict_action(
116
+ processor=processor,
117
+ images=[agentview_rgb, wrist_rgb],
118
+ task=task,
119
+ state=robot_state,
120
+ norm_tag="libero",
121
+ action_mode="continuous",
122
+ enable_depth_reasoning=False,
123
+ num_steps=10,
124
+ normalize_language=True,
125
+ enable_cuda_graph=True,
126
+ )
127
+
128
+ actions = out.actions
129
+ ```
130
+
131
+ `images` should preserve camera order, for example `[agentview_rgb, wrist_rgb]`. Images may be PIL images or RGB arrays. `state` is the raw robot state, and actions are returned in robot scale.
132
+
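+ As a rough sketch of how the returned chunk might be consumed, the loop below steps a LIBERO-style environment with each action in robot scale. The `env` object and its gym-style `step` signature are assumptions for illustration, not part of this repository.
+
+ ```python
+ import numpy as np
+
+ # `out.actions` is the predicted action chunk in robot scale
+ # (roughly action_horizon x action_dim for this checkpoint).
+ action_chunk = np.asarray(out.actions)
+
+ for action in action_chunk:
+     # `env` is a hypothetical LIBERO environment wrapper (gym-style API).
+     obs, reward, done, info = env.step(action)
+     if done:
+         break
+ ```
+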
133
+ `normalize_language=True` is the default. It lowercases the task string and removes trailing sentence punctuation to match training preprocessing. Set it to `False` if you need to preserve the task text exactly.
134
+
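+ For reference, the language normalization is roughly equivalent to the sketch below. It mirrors only the documented behavior and is not the model's actual preprocessing code.
+
+ ```python
+ def normalize_task(task: str) -> str:
+     # Lowercase and drop trailing sentence punctuation, per the description above.
+     return task.lower().rstrip(" .!?")
+
+ print(normalize_task("Put the white mug on the left plate."))
+ # -> "put the white mug on the left plate"
+ ```
+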
135
+ `enable_cuda_graph=True` is the default. The first few calls can be slow while the model warms up and captures CUDA graphs, so run several warm-up calls with dummy inputs before measuring deployment latency. `num_steps` sets the number of flow-matching integration steps and defaults to the checkpoint config value (10).
136
+
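+ A minimal warm-up loop might look like the following; the sample inputs from above are reused and the call count is arbitrary.
+
+ ```python
+ # Run a few throwaway predictions so CUDA graph capture and compilation
+ # happen before any latency measurement.
+ for _ in range(5):
+     model.predict_action(
+         processor=processor,
+         images=[agentview_rgb, wrist_rgb],
+         task=task,
+         state=robot_state,
+         norm_tag="libero",
+         action_mode="continuous",
+         enable_cuda_graph=True,
+     )
+ torch.cuda.synchronize()
+ ```
+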
137
+ Depth reasoning is disabled for this checkpoint. Passing `enable_depth_reasoning=True` will raise an error.
138
+
139
+ ## Discrete Actions
140
+
141
+ Discrete action inference requires a caller-provided action tokenizer, which is not saved in this repository. Discrete mode decodes action tokens directly; the continuous action expert is not used.
142
+
143
+ ```python
144
+ action_tokenizer = AutoProcessor.from_pretrained(
145
+ "YOUR_ACTION_TOKENIZER_REPO",
146
+ trust_remote_code=True,
147
+ )
148
+
149
+ out = model.predict_action(
150
+ processor=processor,
151
+ images=[agentview_rgb, wrist_rgb],
152
+ task=task,
153
+ state=robot_state,
154
+ norm_tag="libero",
155
+ action_mode="discrete",
156
+ action_tokenizer=action_tokenizer,
157
+ enable_depth_reasoning=False,
158
+ )
159
+ ```
assets/MolmoAct2.svg ADDED
assets/sample_agentview_rgb.png ADDED
assets/sample_wrist_rgb.png ADDED
chat_template.jinja ADDED
@@ -0,0 +1 @@
1
+ {% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}{% if not (has_subtitle and loop.index == 2) and not (not has_subtitle and loop.first) %}{{ '<|im_end|>\n' }}{% endif %}{{ '<|im_start|>user\n' }}{{ text_content }}{{ '<|im_end|>\n' }}{% else %} {# assistant #}{{ '<|im_start|>assistant\n' }}{{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
config.json ADDED
@@ -0,0 +1,160 @@
1
+ {
2
+ "action_end_token_id": 151933,
3
+ "action_expert_condition_source": "kv_cache",
4
+ "action_expert_config": {
5
+ "attn_dropout": 0.0,
6
+ "causal_attn": false,
7
+ "compile": "blocks",
8
+ "context_layer_norm": true,
9
+ "dropout": 0.0,
10
+ "ffn_multiple_of": 256,
11
+ "hidden_size": 768,
12
+ "implementation": "new",
13
+ "max_action_dim": 32,
14
+ "max_horizon": 32,
15
+ "mlp_ratio": 4.0,
16
+ "model_type": "molmoact2_action_expert",
17
+ "num_heads": 8,
18
+ "num_layers": 36,
19
+ "qk_norm": true,
20
+ "qk_norm_eps": 1e-06,
21
+ "rope": true,
22
+ "rope_on_cross_attention": true,
23
+ "timestep_embed_dim": 256
24
+ },
25
+ "action_expert_depth_gate": false,
26
+ "action_expert_depth_gate_init_bias": -4.0,
27
+ "action_expert_depth_gate_per_layer": false,
28
+ "action_expert_layer_mode": "per_layer",
29
+ "action_format": "both",
30
+ "action_horizon": 10,
31
+ "action_output_token_id": 151931,
32
+ "action_start_token_id": 151932,
33
+ "action_token_start_id": 151934,
34
+ "adapter_config": {
35
+ "attention_dropout": 0.0,
36
+ "attn_implementation": "sdpa",
37
+ "float32_attention": true,
38
+ "head_dim": 72,
39
+ "hidden_act": "silu",
40
+ "hidden_size": 1152,
41
+ "image_feature_dropout": 0.0,
42
+ "initializer_range": 0.02,
43
+ "intermediate_size": 9728,
44
+ "model_type": "molmoact2",
45
+ "num_attention_heads": 16,
46
+ "num_key_value_heads": 16,
47
+ "pooling_attention_mask": true,
48
+ "residual_dropout": 0.0,
49
+ "text_hidden_size": 2560,
50
+ "vit_layers": [
51
+ -3,
52
+ -9
53
+ ]
54
+ },
55
+ "add_action_expert": true,
56
+ "add_control_tokens": true,
57
+ "add_setup_tokens": true,
58
+ "architectures": [
59
+ "MolmoAct2ForConditionalGeneration"
60
+ ],
61
+ "auto_map": {
62
+ "AutoConfig": "configuration_molmoact2.MolmoAct2Config",
63
+ "AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
64
+ },
65
+ "depth_end_token_id": null,
66
+ "depth_mode": 2,
67
+ "depth_output_token_id": null,
68
+ "depth_start_token_id": null,
69
+ "depth_token_start_id": null,
70
+ "dtype": "float32",
71
+ "enable_depth_reasoning": false,
72
+ "flow_matching_beta_alpha": 1.0,
73
+ "flow_matching_beta_beta": 1.5,
74
+ "flow_matching_cutoff": 1.0,
75
+ "flow_matching_num_steps": 10,
76
+ "flow_matching_time_offset": 0.001,
77
+ "flow_matching_time_scale": 0.999,
78
+ "frame_end_token_id": 154632,
79
+ "frame_start_token_id": 154631,
80
+ "image_col_id": 154627,
81
+ "image_end_token_id": 154625,
82
+ "image_high_res_id": 154626,
83
+ "image_low_res_id": 154630,
84
+ "image_patch_id": 154626,
85
+ "image_start_token_id": 154624,
86
+ "initializer_range": 0.02,
87
+ "low_res_image_start_token_id": 154628,
88
+ "mask_action_dim_padding": true,
89
+ "max_action_dim": 32,
90
+ "model_type": "molmoact2",
91
+ "n_obs_steps": 1,
92
+ "norm_stats_filename": "norm_stats.json",
93
+ "num_action_tokens": 2048,
94
+ "num_depth_codes": 100,
95
+ "num_depth_tokens": 0,
96
+ "num_state_tokens": 256,
97
+ "state_end_token_id": 151674,
98
+ "state_format": "discrete",
99
+ "state_start_token_id": 151673,
100
+ "state_token_start_id": 151675,
101
+ "text_config": {
102
+ "additional_vocab_size": 128,
103
+ "attention_dropout": 0.0,
104
+ "attn_implementation": "sdpa",
105
+ "embedding_dropout": 0.0,
106
+ "head_dim": 128,
107
+ "hidden_act": "silu",
108
+ "hidden_size": 2560,
109
+ "initializer_range": 0.02,
110
+ "intermediate_size": 9728,
111
+ "layer_norm_eps": 1e-06,
112
+ "max_position_embeddings": 16384,
113
+ "model_type": "molmoact2_text",
114
+ "norm_after": false,
115
+ "num_attention_heads": 32,
116
+ "num_hidden_layers": 36,
117
+ "num_key_value_heads": 8,
118
+ "qk_norm_type": "qwen3",
119
+ "qkv_bias": false,
120
+ "residual_dropout": 0.0,
121
+ "rope_parameters": {
122
+ "rope_theta": 5000000.0,
123
+ "rope_type": "default"
124
+ },
125
+ "rope_scaling_layers": null,
126
+ "rope_theta": 5000000.0,
127
+ "tie_word_embeddings": false,
128
+ "use_cache": true,
129
+ "use_qk_norm": true,
130
+ "vocab_size": 154624
131
+ },
132
+ "tie_word_embeddings": false,
133
+ "transformers_version": "5.3.0",
134
+ "use_frame_special_tokens": true,
135
+ "vit_config": {
136
+ "attention_dropout": 0.0,
137
+ "attn_implementation": "sdpa",
138
+ "float32_attention": true,
139
+ "head_dim": 72,
140
+ "hidden_act": "gelu_pytorch_tanh",
141
+ "hidden_size": 1152,
142
+ "image_default_input_size": [
143
+ 378,
144
+ 378
145
+ ],
146
+ "image_num_pos": 729,
147
+ "image_patch_size": 14,
148
+ "initializer_range": 0.02,
149
+ "intermediate_size": 4304,
150
+ "layer_norm_eps": 1e-06,
151
+ "model_type": "molmoact2",
152
+ "num_attention_heads": 16,
153
+ "num_hidden_layers": 27,
154
+ "num_key_value_heads": 16,
155
+ "residual_dropout": 0.0
156
+ },
157
+ "bos_token_id": 151645,
158
+ "eos_token_id": 151645,
159
+ "pad_token_id": 151643
160
+ }
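
The action-related fields above (for example `action_horizon`, `flow_matching_num_steps`, and `enable_depth_reasoning`) can also be read programmatically; a small sketch, assuming the repository's custom code is trusted:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("allenai/MolmoAct2-LIBERO", trust_remote_code=True)
print(config.action_horizon)           # 10
print(config.flow_matching_num_steps)  # 10
print(config.enable_depth_reasoning)   # False
```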
configuration_molmoact2.py ADDED
@@ -0,0 +1,565 @@
1
+ """
2
+ MolmoAct2 configuration
3
+ """
4
+
5
+ from typing import Optional, Any
6
+
7
+ from transformers import PretrainedConfig
8
+ from transformers.modeling_rope_utils import rope_config_validation
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class MolmoAct2VitConfig(PretrainedConfig):
15
+ r"""
16
+ This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
17
+ It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
18
+ defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Example:
24
+ ```python
25
+ >>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
26
+
27
+ >>> # Initializing a MolmoAct2VitConfig
28
+ >>> configuration = MolmoAct2VitConfig()
29
+
30
+ >>> # Initializing a MolmoAct2VisionTransformer (with random weights)
31
+ >>> model = MolmoAct2VisionTransformer(configuration)
32
+
33
+ >>> # Accessing the model configuration
34
+ >>> configuration = model.config
35
+ ```"""
36
+
37
+ model_type = "molmoact2"
38
+ base_config_key = "vit_config"
39
+
40
+ def __init__(
41
+ self,
42
+ hidden_size: int = 1152,
43
+ intermediate_size: int = 4304,
44
+ num_hidden_layers: int = 27,
45
+ num_attention_heads: int = 16,
46
+ num_key_value_heads: int = 16,
47
+ head_dim: int = 72,
48
+ hidden_act: str = "gelu_pytorch_tanh",
49
+ layer_norm_eps: float = 1e-6,
50
+ image_default_input_size: tuple[int, int] = (378, 378),
51
+ image_patch_size: int = 14,
52
+ image_num_pos: int = 577,
53
+ attention_dropout: float = 0.0,
54
+ residual_dropout: float = 0.0,
55
+ initializer_range: float = 0.02,
56
+ float32_attention: bool = True,
57
+ attn_implementation: str = "eager",
58
+ **kwargs,
59
+ ):
60
+ self.attn_implementation = attn_implementation
61
+ super().__init__(
62
+ attn_implementation=attn_implementation,
63
+ **kwargs
64
+ )
65
+ self.hidden_size = hidden_size
66
+ self.intermediate_size = intermediate_size
67
+ self.num_hidden_layers = num_hidden_layers
68
+ self.num_attention_heads = num_attention_heads
69
+ self.num_key_value_heads = num_key_value_heads
70
+ self.head_dim = head_dim
71
+ self.hidden_act = hidden_act
72
+ self.layer_norm_eps = layer_norm_eps
73
+ self.image_default_input_size = image_default_input_size
74
+ self.image_patch_size = image_patch_size
75
+ self.image_num_pos = image_num_pos
76
+ self.attention_dropout = attention_dropout
77
+ self.residual_dropout = residual_dropout
78
+ self.initializer_range = initializer_range
79
+ self.float32_attention = float32_attention
80
+
81
+ @property
82
+ def image_num_patch(self):
83
+ h, w = self.image_default_input_size
84
+ return h // self.image_patch_size, w // self.image_patch_size
85
+
86
+
87
+ class MolmoAct2AdapterConfig(PretrainedConfig):
88
+ r"""
89
+ This is the configuration class to store the configuration of a MolmoAct2Adapter. Together with MolmoAct2VitConfig,
90
+ it is used to instantiate a MolmoAct2VisionBackbone according to the specified arguments,
91
+ defining the model architecture.
92
+
93
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
94
+ documentation from [`PretrainedConfig`] for more information.
95
+
96
+ Example:
97
+
98
+ ```python
99
+ >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
100
+
101
+ >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
102
+ >>> vit_config = MolmoAct2VitConfig()
103
+ >>> adapter_config = MolmoAct2AdapterConfig()
104
+
105
+ >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
106
+ >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> vit_configuration = model.vit_config
110
+ >>> adapter_configuration = model.adapter_config
111
+ ```"""
112
+
113
+ model_type = "molmoact2"
114
+ base_config_key = "adapter_config"
115
+
116
+ def __init__(
117
+ self,
118
+ vit_layers: tuple = (-3, -9),
119
+ pooling_attention_mask: bool = False,
120
+ hidden_size: int = 1152,
121
+ num_attention_heads: int = 16,
122
+ num_key_value_heads: int = 16,
123
+ head_dim: int = 72,
124
+ float32_attention: bool = True,
125
+ attention_dropout: float = 0.0,
126
+ residual_dropout: float = 0.0,
127
+ hidden_act: str = "silu",
128
+ intermediate_size: int = 18944,
129
+ text_hidden_size: int = 3584,
130
+ image_feature_dropout: float = 0.0,
131
+ initializer_range: float = 0.02,
132
+ attn_implementation: str = "eager",
133
+ **kwargs,
134
+ ):
135
+ self.attn_implementation = attn_implementation
136
+ super().__init__(
137
+ attn_implementation=attn_implementation,
138
+ **kwargs
139
+ )
140
+ self.vit_layers = vit_layers
141
+ self.pooling_attention_mask = pooling_attention_mask
142
+ self.hidden_size = hidden_size
143
+ self.num_attention_heads = num_attention_heads
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.head_dim = head_dim
146
+ self.float32_attention = float32_attention
147
+ self.attention_dropout = attention_dropout
148
+ self.residual_dropout = residual_dropout
149
+ self.hidden_act = hidden_act
150
+ self.intermediate_size = intermediate_size
151
+ self.text_hidden_size = text_hidden_size
152
+ self.image_feature_dropout = image_feature_dropout
153
+ self.initializer_range = initializer_range
154
+
155
+
156
+ class MolmoAct2TextConfig(PretrainedConfig):
157
+ r"""
158
+ This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
159
+ `MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
160
+
161
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
162
+ documentation from [`PretrainedConfig`] for more information.
163
+
164
+ Example:
165
+ ```python
166
+ >>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
167
+
168
+ >>> # Initializing a MolmoAct2TextConfig
169
+ >>> configuration = MolmoAct2TextConfig()
170
+
171
+ >>> # Initializing a MolmoAct2TextModel (with random weights)
172
+ >>> model = MolmoAct2TextModel(configuration)
173
+
174
+ >>> # Accessing the model configuration
175
+ >>> configuration = model.config
176
+ ```"""
177
+
178
+ model_type = "molmoact2_text"
179
+ base_config_key = "text_config"
180
+ keys_to_ignore_at_inference = ["past_key_values"]
181
+ base_model_tp_plan = {
182
+ "blocks.*.self_attn.att_proj": "colwise",
183
+ "blocks.*.self_attn.attn_out": "rowwise",
184
+ "blocks.*.mlp.ff_proj": "colwise",
185
+ "blocks.*.mlp.ff_out": "rowwise",
186
+ }
187
+ base_model_pp_plan = {
188
+ "wte": (["input_ids"], ["inputs_embeds"]),
189
+ "blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
190
+ "ln_f": (["hidden_states"], ["hidden_states"]),
191
+ }
192
+
193
+ def __init__(
194
+ self,
195
+ hidden_size: int = 3584,
196
+ num_attention_heads: int = 28,
197
+ num_key_value_heads: Optional[int] = 4,
198
+ head_dim: int = 128,
199
+ vocab_size: int = 152064,
200
+ additional_vocab_size: int = 128,
201
+ qkv_bias: bool = True,
202
+ num_hidden_layers: int = 48,
203
+ intermediate_size: int = 18944,
204
+ hidden_act: str = "silu",
205
+ embedding_dropout: float=0.0,
206
+ attention_dropout: float=0.0,
207
+ residual_dropout: float = 0.0,
208
+ max_position_embeddings: int = 4096,
209
+ rope_theta: float = 1000000.0,
210
+ rope_scaling: dict[str, Any] = None,
211
+ rope_scaling_layers: Optional[list[int]] = None,
212
+ use_qk_norm: bool = False,
213
+ qk_norm_type: str = "olmo",
214
+ layer_norm_eps: int = 1e-6,
215
+ norm_after: bool = False,
216
+ initializer_range: float = 0.02,
217
+ use_cache=True,
218
+ tie_word_embeddings=False,
219
+ attn_implementation: str = "eager",
220
+ **kwargs,
221
+ ):
222
+ self.attn_implementation = attn_implementation
223
+ super().__init__(
224
+ tie_word_embeddings=tie_word_embeddings,
225
+ attn_implementation=attn_implementation,
226
+ **kwargs
227
+ )
228
+ self.hidden_size = hidden_size
229
+ self.num_attention_heads = num_attention_heads
230
+ if num_key_value_heads is None:
231
+ num_key_value_heads = num_attention_heads
232
+ self.num_key_value_heads = num_key_value_heads
233
+ self.head_dim = head_dim
234
+ self.vocab_size = vocab_size
235
+ self.additional_vocab_size = additional_vocab_size
236
+ self.qkv_bias = qkv_bias
237
+ self.num_hidden_layers = num_hidden_layers
238
+ self.intermediate_size = intermediate_size
239
+ self.hidden_act = hidden_act
240
+ self.embedding_dropout = embedding_dropout
241
+ self.attention_dropout = attention_dropout
242
+ self.residual_dropout = residual_dropout
243
+ self.max_position_embeddings = max_position_embeddings
244
+ self.rope_theta = rope_theta
245
+ self.rope_scaling = rope_scaling
246
+ self.rope_scaling_layers = rope_scaling_layers
247
+ self.use_qk_norm = use_qk_norm
248
+ self.qk_norm_type = qk_norm_type
249
+ self.layer_norm_eps = layer_norm_eps
250
+ self.norm_after = norm_after
251
+ self.initializer_range = initializer_range
252
+ self.use_cache = use_cache
253
+
254
+ # Validate the correctness of rotary position embeddings parameters
255
+ rope_config_validation(self)
256
+
257
+
258
+ class MolmoAct2ActionExpertConfig(PretrainedConfig):
259
+ r"""Configuration for the MolmoAct2 modern action expert."""
260
+
261
+ model_type = "molmoact2_action_expert"
262
+ base_config_key = "action_expert_config"
263
+
264
+ def __init__(
265
+ self,
266
+ max_horizon: int = 32,
267
+ max_action_dim: int = 14,
268
+ hidden_size: int = 1024,
269
+ num_layers: int = 32,
270
+ num_heads: int = 16,
271
+ mlp_ratio: float = 8.0 / 3.0,
272
+ ffn_multiple_of: int = 256,
273
+ timestep_embed_dim: int = 256,
274
+ dropout: float = 0.0,
275
+ attn_dropout: float = 0.0,
276
+ context_layer_norm: bool = True,
277
+ qk_norm: bool = True,
278
+ qk_norm_eps: float = 1e-6,
279
+ rope: bool = True,
280
+ rope_on_cross_attention: bool = False,
281
+ causal_attn: bool = False,
282
+ compile: str = "blocks",
283
+ implementation: str = "new",
284
+ **kwargs,
285
+ ):
286
+ super().__init__(**kwargs)
287
+ if implementation != "new":
288
+ raise ValueError(
289
+ "MolmoAct2 HF export supports only action_expert.implementation='new'."
290
+ )
291
+ self.max_horizon = max_horizon
292
+ self.max_action_dim = max_action_dim
293
+ self.hidden_size = hidden_size
294
+ self.num_layers = num_layers
295
+ self.num_heads = num_heads
296
+ self.mlp_ratio = mlp_ratio
297
+ self.ffn_multiple_of = ffn_multiple_of
298
+ self.timestep_embed_dim = timestep_embed_dim
299
+ self.dropout = dropout
300
+ self.attn_dropout = attn_dropout
301
+ self.context_layer_norm = context_layer_norm
302
+ self.qk_norm = qk_norm
303
+ self.qk_norm_eps = qk_norm_eps
304
+ self.rope = rope
305
+ self.rope_on_cross_attention = rope_on_cross_attention
306
+ self.causal_attn = causal_attn
307
+ self.compile = compile
308
+ self.implementation = implementation
309
+
310
+
311
+ class MolmoAct2Config(PretrainedConfig):
312
+ r"""
313
+ This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
314
+ It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
315
+
316
+ Example:
317
+
318
+ ```python
319
+ >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
320
+
321
+ >>> # Initializing a MolmoAct2VitConfig
322
+ >>> vit_config = MolmoAct2VitConfig()
323
+
324
+ >>> # Initializing a MolmoAct2AdapterConfig
325
+ >>> adapter_config = MolmoAct2AdapterConfig()
326
+
327
+ >>> # Initializing a MolmoAct2TextConfig
328
+ >>> text_config = MolmoAct2TextConfig()
329
+
330
+ >>> # Initializing a MolmoAct2Config
331
+ >>> configuration = MolmoAct2Config(
332
+ >>> vit_config=vit_config,
333
+ >>> adapter_config=adapter_config,
334
+ >>> text_config=text_config,
335
+ >>> image_start_token_id=151936,
336
+ >>> image_end_token_id=151937,
337
+ >>> image_patch_id=151938,
338
+ >>> image_col_id=151939,
339
+ >>> low_res_image_start_token_id=151940,
340
+ >>> image_low_res_id=151942,
341
+ >>> frame_start_token_id=151943,
342
+ >>> frame_end_token_id=151944,
343
+ >>> )
344
+
345
+ >>> # Initializing a model
346
+ >>> model = MolmoAct2ForConditionalGeneration(configuration)
347
+
348
+ >>> # Accessing the model configuration
349
+ >>> configuration = model.config
350
+ ```"""
351
+
352
+ model_type = "molmoact2"
353
+ sub_configs = {
354
+ "text_config": MolmoAct2TextConfig,
355
+ "vit_config": MolmoAct2VitConfig,
356
+ "adapter_config": MolmoAct2AdapterConfig,
357
+ "action_expert_config": MolmoAct2ActionExpertConfig,
358
+ }
359
+
360
+ def __init__(
361
+ self,
362
+ vit_config: MolmoAct2VitConfig = None,
363
+ adapter_config: MolmoAct2AdapterConfig = None,
364
+ text_config: MolmoAct2TextConfig = None,
365
+ action_expert_config: MolmoAct2ActionExpertConfig = None,
366
+ image_start_token_id: int = None,
367
+ low_res_image_start_token_id: int = None,
368
+ image_end_token_id: int = None,
369
+ image_low_res_id: int = None,
370
+ image_patch_id: int = None,
371
+ image_col_id: int = None,
372
+ frame_start_token_id: int = None,
373
+ frame_end_token_id: int = None,
374
+ use_frame_special_tokens: bool = True,
375
+ initializer_range: float = 0.02,
376
+ add_action_expert: bool = True,
377
+ max_action_dim: int = 7,
378
+ action_horizon: int = 16,
379
+ n_obs_steps: int = 1,
380
+ action_format: str = "continuous",
381
+ state_format: str = "discrete",
382
+ action_expert_condition_source: str = "kv_cache",
383
+ action_expert_layer_mode: str = "per_layer",
384
+ flow_matching_num_steps: int = 10,
385
+ flow_matching_cutoff: float = 1.0,
386
+ flow_matching_time_offset: float = 0.001,
387
+ flow_matching_time_scale: float = 0.999,
388
+ flow_matching_beta_alpha: float = 1.0,
389
+ flow_matching_beta_beta: float = 1.5,
390
+ mask_action_dim_padding: bool = True,
391
+ enable_depth_reasoning: bool = False,
392
+ depth_mode: int = 2,
393
+ num_depth_codes: int = 100,
394
+ action_expert_depth_gate: bool = False,
395
+ action_expert_depth_gate_per_layer: bool = False,
396
+ action_expert_depth_gate_init_bias: float = -4.0,
397
+ action_output_token_id: int = None,
398
+ action_start_token_id: int = None,
399
+ action_end_token_id: int = None,
400
+ action_token_start_id: int = None,
401
+ num_action_tokens: int = 0,
402
+ depth_output_token_id: int = None,
403
+ depth_start_token_id: int = None,
404
+ depth_end_token_id: int = None,
405
+ depth_token_start_id: int = None,
406
+ num_depth_tokens: int = 0,
407
+ state_start_token_id: int = None,
408
+ state_end_token_id: int = None,
409
+ state_token_start_id: int = None,
410
+ num_state_tokens: int = 0,
411
+ add_setup_tokens: bool = True,
412
+ add_control_tokens: bool = True,
413
+ norm_stats_filename: str = "norm_stats.json",
414
+ **kwargs,
415
+ ):
416
+ super().__init__(**kwargs)
417
+ if vit_config is None:
418
+ self.vit_config = MolmoAct2VitConfig()
419
+ elif isinstance(vit_config, dict):
420
+ self.vit_config = MolmoAct2VitConfig(**vit_config)
421
+ else:
422
+ self.vit_config = vit_config
423
+ if adapter_config is None:
424
+ self.adapter_config = MolmoAct2AdapterConfig()
425
+ elif isinstance(adapter_config, dict):
426
+ self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
427
+ else:
428
+ self.adapter_config = adapter_config
429
+ if text_config is None:
430
+ self.text_config = MolmoAct2TextConfig()
431
+ elif isinstance(text_config, dict):
432
+ self.text_config = MolmoAct2TextConfig(**text_config)
433
+ else:
434
+ self.text_config = text_config
435
+ self.add_action_expert = bool(add_action_expert)
436
+ if not self.add_action_expert:
437
+ self.action_expert_config = None
438
+ elif action_expert_config is None:
439
+ self.action_expert_config = MolmoAct2ActionExpertConfig(
440
+ max_horizon=action_horizon,
441
+ max_action_dim=max_action_dim,
442
+ num_layers=self.text_config.num_hidden_layers,
443
+ )
444
+ elif isinstance(action_expert_config, dict):
445
+ self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
446
+ else:
447
+ self.action_expert_config = action_expert_config
448
+ if self.add_action_expert:
449
+ self._validate_release_action_config(
450
+ action_expert_config=self.action_expert_config,
451
+ action_expert_condition_source=action_expert_condition_source,
452
+ action_expert_layer_mode=action_expert_layer_mode,
453
+ state_format=state_format,
454
+ )
455
+ self.image_start_token_id = image_start_token_id
456
+ self.low_res_image_start_token_id = low_res_image_start_token_id
457
+ self.image_end_token_id = image_end_token_id
458
+ self.image_low_res_id = image_low_res_id
459
+ self.image_high_res_id = image_patch_id
460
+ self.image_patch_id = image_patch_id
461
+ self.image_col_id = image_col_id
462
+ self.frame_start_token_id = frame_start_token_id
463
+ self.frame_end_token_id = frame_end_token_id
464
+ self.use_frame_special_tokens = use_frame_special_tokens
465
+ self.initializer_range = initializer_range
466
+ self.max_action_dim = max_action_dim
467
+ self.action_horizon = action_horizon
468
+ self.n_obs_steps = n_obs_steps
469
+ self.action_format = action_format
470
+ self.state_format = state_format
471
+ self.action_expert_condition_source = action_expert_condition_source
472
+ self.action_expert_layer_mode = action_expert_layer_mode
473
+ self.flow_matching_num_steps = flow_matching_num_steps
474
+ self.flow_matching_cutoff = flow_matching_cutoff
475
+ self.flow_matching_time_offset = flow_matching_time_offset
476
+ self.flow_matching_time_scale = flow_matching_time_scale
477
+ self.flow_matching_beta_alpha = flow_matching_beta_alpha
478
+ self.flow_matching_beta_beta = flow_matching_beta_beta
479
+ self.mask_action_dim_padding = mask_action_dim_padding
480
+ self.enable_depth_reasoning = enable_depth_reasoning
481
+ self.depth_mode = depth_mode
482
+ self.num_depth_codes = num_depth_codes
483
+ self.action_expert_depth_gate = action_expert_depth_gate
484
+ self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
485
+ self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
486
+ self.action_output_token_id = action_output_token_id
487
+ self.action_start_token_id = action_start_token_id
488
+ self.action_end_token_id = action_end_token_id
489
+ self.action_token_start_id = action_token_start_id
490
+ self.num_action_tokens = num_action_tokens
491
+ self.depth_output_token_id = depth_output_token_id
492
+ self.depth_start_token_id = depth_start_token_id
493
+ self.depth_end_token_id = depth_end_token_id
494
+ self.depth_token_start_id = depth_token_start_id
495
+ self.num_depth_tokens = num_depth_tokens
496
+ self.state_start_token_id = state_start_token_id
497
+ self.state_end_token_id = state_end_token_id
498
+ self.state_token_start_id = state_token_start_id
499
+ self.num_state_tokens = num_state_tokens
500
+ self.add_setup_tokens = add_setup_tokens
501
+ self.add_control_tokens = add_control_tokens
502
+ self.norm_stats_filename = norm_stats_filename
503
+
504
+ @staticmethod
505
+ def _validate_release_action_config(
506
+ *,
507
+ action_expert_config: MolmoAct2ActionExpertConfig,
508
+ action_expert_condition_source: str,
509
+ action_expert_layer_mode: str,
510
+ state_format: str,
511
+ ) -> None:
512
+ if action_expert_config.implementation != "new":
513
+ raise ValueError(
514
+ "MolmoAct2 HF export supports only action_expert.implementation='new'."
515
+ )
516
+ if action_expert_condition_source != "kv_cache":
517
+ raise ValueError(
518
+ "MolmoAct2 HF export supports only action_expert_condition_source='kv_cache'."
519
+ )
520
+ if action_expert_layer_mode != "per_layer":
521
+ raise ValueError(
522
+ "MolmoAct2 HF export supports only action_expert_layer_mode='per_layer'."
523
+ )
524
+ if state_format != "discrete":
525
+ raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
526
+
527
+ @property
528
+ def image_num_patch(self):
529
+ assert self.vit_config is not None
530
+ return self.vit_config.image_num_patch
531
+
532
+ @property
533
+ def num_attention_heads(self):
534
+ return self.text_config.num_attention_heads
535
+
536
+ @property
537
+ def num_key_value_heads(self):
538
+ return self.text_config.num_key_value_heads
539
+
540
+ @property
541
+ def head_dim(self):
542
+ return self.text_config.head_dim
543
+
544
+ @property
545
+ def num_hidden_layers(self):
546
+ return self.text_config.num_hidden_layers
547
+
548
+ @property
549
+ def hidden_size(self):
550
+ return self.text_config.hidden_size
551
+
552
+ @property
553
+ def vocab_size(self):
554
+ return self.text_config.vocab_size
555
+
556
+ @property
557
+ def max_position_embeddings(self):
558
+ return self.text_config.max_position_embeddings
559
+
560
+
561
+ MolmoAct2VitConfig.register_for_auto_class()
562
+ MolmoAct2AdapterConfig.register_for_auto_class()
563
+ MolmoAct2TextConfig.register_for_auto_class()
564
+ MolmoAct2ActionExpertConfig.register_for_auto_class()
565
+ MolmoAct2Config.register_for_auto_class()
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token_id": 151645,
3
+ "eos_token_id": 151645,
4
+ "pad_token_id": 151643,
5
+ "transformers_version": "5.3.0"
6
+ }
image_processing_molmoact2.py ADDED
@@ -0,0 +1,546 @@
1
+ """Image processor class for MolmoAct2"""
2
+ from typing import Optional, Union
3
+ import numpy as np
4
+ import einops
5
+ import torch
6
+ import torchvision.transforms
7
+
8
+ from transformers.image_utils import (
9
+ IMAGENET_STANDARD_MEAN,
10
+ IMAGENET_STANDARD_STD,
11
+ ImageInput,
12
+ PILImageResampling,
13
+ make_flat_list_of_images,
14
+ valid_images,
15
+ to_numpy_array,
16
+ )
17
+ from transformers.image_transforms import convert_to_rgb
18
+ from transformers.processing_utils import ImagesKwargs
19
+ from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
20
+ from transformers.utils import logging
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.utils import TensorType, logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def normalize_image(
29
+ image: np.ndarray,
30
+ image_mean: list[float],
31
+ image_std: list[float],
32
+ ) -> np.ndarray:
33
+ if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
34
+ return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
35
+ image -= np.array(image_mean, dtype=np.float32)[None, None, :]
36
+ image /= np.array(image_std, dtype=np.float32)[None, None, :]
37
+ return image
38
+
39
+
40
+ def resize_image(
41
+ image: np.ndarray,
42
+ desired_output_size: list[int],
43
+ resample: PILImageResampling,
44
+ ) -> np.ndarray:
45
+ image = torch.permute(torch.from_numpy(image), [2, 0, 1])
46
+ dtype = image.dtype
47
+ if torch.is_floating_point(image):
48
+ in_min = 0.0
49
+ in_max = 1.0
50
+ resized = torchvision.transforms.Resize(
51
+ desired_output_size,
52
+ resample,
53
+ antialias=False,
54
+ )(image)
55
+ resized = torch.clip(resized, 0.0, 1.0).to(dtype)
56
+ else:
57
+ assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
58
+ in_min = 0.0
59
+ in_max = 255.0
60
+ resized = torchvision.transforms.Resize(
61
+ desired_output_size,
62
+ resample,
63
+ antialias=False,
64
+ )(image)
65
+ resized = torch.clip(resized, 0, 255).to(dtype)
66
+
67
+ resized = resized.to(torch.float32)
68
+ resized = (resized - in_min) / (in_max - in_min)
69
+
70
+ resized = torch.permute(resized, [1, 2, 0]).numpy()
71
+
72
+ return resized
73
+
74
+
75
+ def select_tiling(h, w, patch_size, max_num_crops):
76
+ """Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
77
+ original_size = np.stack([h, w]) # [1, 2]
78
+ original_res = h * w
79
+ tilings = []
80
+ for i in range(1, max_num_crops + 1):
81
+ for j in range(1, max_num_crops + 1):
82
+ if i*j <= max_num_crops:
83
+ tilings.append((i, j))
84
+ # sort so argmin and argmax favour smaller tilings in the event of a tie
85
+ tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
86
+ candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
87
+ candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
88
+
89
+ # How much we would need to scale the image to fit exactly in each tiling
90
+ original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
91
+
92
+ # The original size can be zero in rare cases if the image is smaller than the margin
93
+ # In those cases letting the scale become infinite means the tiling is based on the
94
+ # other side, or falls back to the smallest tiling
95
+ with np.errstate(divide='ignore'):
96
+ required_scale_d = candidate_resolutions.astype(np.float32) / original_size,
97
+ required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
98
+ if np.all(required_scale < 1):
99
+ # We are forced to downscale, so try to minimize the amount of downscaling
100
+ ix = np.argmax(required_scale)
101
+ else:
102
+ # Pick the resolution that required the least upscaling so that it most closely fits the image
103
+ required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
104
+ ix = np.argmin(required_scale)
105
+ return candidate_tilings[ix]
106
+
107
+
108
+ def build_resized_image(
109
+ image: np.ndarray,
110
+ base_image_input_size: list[int],
111
+ resample: PILImageResampling,
112
+ image_mean: list[float],
113
+ image_std: list[float],
114
+ image_patch_size: int,
115
+ ) -> tuple[np.ndarray, np.ndarray]:
116
+ resized = resize_image(
117
+ image, base_image_input_size, resample,
118
+ )
119
+ resized = normalize_image(resized, image_mean, image_std)
120
+ if len(resized.shape) == 3:
121
+ resized = np.expand_dims(resized, 0)
122
+ crop_patch_w = base_image_input_size[1] // image_patch_size
123
+ crop_patch_h = base_image_input_size[0] // image_patch_size
124
+ resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
125
+ return resized, resize_idx
126
+
127
+
128
+ def build_overlapping_crops(
129
+ image: np.ndarray,
130
+ max_crops: int,
131
+ overlap_margins: list[int],
132
+ base_image_input_size: list[int],
133
+ resample: PILImageResampling,
134
+ image_mean: list[float],
135
+ image_std: list[float],
136
+ image_patch_size: int,
137
+ ) -> tuple[np.ndarray, np.ndarray]:
138
+ """Decompose an image into a set of overlapping crops
139
+
140
+ :return crop_arr: [n_crops, h, w, 3] The crops
141
+ :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
142
+ the crops were extracted from, what patch in `crop_arr` it corresponds to
143
+ """
144
+ original_image_h, original_image_w = image.shape[:2]
145
+ crop_size = base_image_input_size[0]
146
+ assert base_image_input_size[0] == base_image_input_size[1]
147
+
148
+ left_margin, right_margin = overlap_margins
149
+ total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
150
+ crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
151
+ crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
152
+ crop_window_size = crop_window_patches * image_patch_size
153
+ crop_patch_w = base_image_input_size[1] // image_patch_size
154
+ crop_patch_h = base_image_input_size[0] // image_patch_size
155
+ original_image_h, original_image_w = image.shape[:2]
156
+ crop_size = base_image_input_size[0]
157
+
158
+ # Decide how to tile the image, to account for the overlap margins we compute the tiling
159
+ # as if we had an image without the margins and were using a crop size without the margins
160
+ tiling = select_tiling(
161
+ original_image_h - total_margin_pixels,
162
+ original_image_w - total_margin_pixels,
163
+ crop_window_size,
164
+ max_crops,
165
+ )
166
+
167
+ src = resize_image(
168
+ image,
169
+ [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
170
+ resample,
171
+ )
172
+ src = normalize_image(src, image_mean, image_std)
173
+
174
+ # Now we have to split the image into crops, and track what patches came from
175
+ # where in `patch_idx_arr`
176
+ n_crops = tiling[0] * tiling[1]
177
+ crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
178
+ patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
179
+ on_crop = 0
180
+ for i in range(tiling[0]):
181
+ # Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
182
+ # which results in overlapping crop windows
183
+ y0 = i*crop_window_size
184
+ for j in range(tiling[1]):
185
+ x0 = j*crop_window_size
186
+ crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
187
+ patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
188
+ patch_idx += on_crop * crop_patch_h * crop_patch_w
189
+
190
+ # Mask out idx that are in the overlap region
191
+ if i != 0:
192
+ patch_idx[:left_margin, :] = -1
193
+ if j != 0:
194
+ patch_idx[:, :left_margin] = -1
195
+ if i != tiling[0]-1:
196
+ patch_idx[-right_margin:, :] = -1
197
+ if j != tiling[1]-1:
198
+ patch_idx[:, -right_margin:] = -1
199
+ patch_idx_arr[on_crop] = patch_idx
200
+ on_crop += 1
201
+
202
+ # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
203
+ # so it is ordered left-to-right order
204
+ patch_idx_arr = np.reshape(
205
+ patch_idx_arr,
206
+ [tiling[0], tiling[1], crop_patch_h, crop_patch_w]
207
+ )
208
+ patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
209
+ patch_idx_arr = np.reshape(patch_idx_arr, [-1])
210
+
211
+ # Now get the parts not in the overlap region, so it should map each patch in `src`
212
+ # to the correct patch it should come from in `crop_arr`
213
+ patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
214
+ src.shape[0]//image_patch_size,
215
+ src.shape[1]//image_patch_size,
216
+ )
217
+ return crop_arr, patch_idx_arr
218
+
219
+
220
+ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
221
+ """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
222
+ if len(array.shape) == 3:
223
+ n_crops, h, w = array.shape
224
+ h_patches = h//patch_size
225
+ w_patches = w//patch_size
226
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
227
+ array = np.transpose(array, [0, 1, 3, 2, 4])
228
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
229
+ return array
230
+ else:
231
+ n_crops, h, w, c = array.shape
232
+ h_patches = h//patch_size
233
+ w_patches = w//patch_size
234
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
235
+ array = np.transpose(array, [0, 1, 3, 2, 4, 5])
236
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
237
+ return array
238
+
239
+
240
+ def arange_for_pooling(
241
+ idx_arr: np.ndarray,
242
+ pool_h: int,
243
+ pool_w: int,
244
+ ) -> np.ndarray:
245
+ h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
246
+ w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
247
+ idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
248
+ mode='constant',constant_values=-1)
249
+ return einops.rearrange(
250
+ idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
251
+
252
+
253
+ def image_to_patches_and_grids(
254
+ image: np.ndarray,
255
+ max_crops: int,
256
+ overlap_margins: list[int],
257
+ base_image_input_size: list[int],
258
+ resample: PILImageResampling,
259
+ image_mean: list[float],
260
+ image_std: list[float],
261
+ image_patch_size: int,
262
+ image_pooling_w: int,
263
+ image_pooling_h: int,
264
+ crop_mode: str = "overlap-and-resize-c2",
265
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
266
+ """
267
+ :return image_grids, the shape of each (low-res, high-res) image after pooling
268
+ :return crops, the image crops to processes with the ViT
269
+ :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
270
+ patches in `crops` to pool for that token, masked with -1
271
+ """
272
+ if isinstance(base_image_input_size, int):
273
+ base_image_input_size = (base_image_input_size, base_image_input_size)
274
+
275
+ base_image_input_d = image_patch_size
276
+ pooling_w = image_pooling_w
277
+ pooling_h = image_pooling_h
278
+ crop_patch_w = base_image_input_size[1] // base_image_input_d
279
+ crop_patch_h = base_image_input_size[0] // base_image_input_d
280
+
281
+ if crop_mode == "resize":
282
+ resized, resize_idx = build_resized_image(
283
+ image,
284
+ base_image_input_size,
285
+ resample,
286
+ image_mean,
287
+ image_std,
288
+ image_patch_size,
289
+ )
290
+ resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
291
+ resized_h, resized_w = resize_idx.shape[:2]
292
+ resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
293
+ image_grid = [np.array([resized_h, resized_w, 0, 0])]
294
+ return (
295
+ np.stack(image_grid, 0),
296
+ batch_pixels_to_patches(resized, image_patch_size),
297
+ resize_idx,
298
+ )
299
+
300
+ if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
301
+ raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
302
+
303
+ crop_arr, patch_idx_arr = build_overlapping_crops(
304
+ image,
305
+ max_crops,
306
+ overlap_margins,
307
+ base_image_input_size,
308
+ resample,
309
+ image_mean,
310
+ image_std,
311
+ image_patch_size,
312
+ )
313
+ pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
314
+ h, w = pooling_idx.shape[:2]
315
+ pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
316
+
317
+ # Finally do the same for the global image
318
+ resized, resize_idx = build_resized_image(
319
+ image,
320
+ base_image_input_size,
321
+ resample,
322
+ image_mean,
323
+ image_std,
324
+ image_patch_size,
325
+ )
326
+ crop_arr = np.concatenate([resized, crop_arr], 0)
327
+
328
+ resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
329
+ resized_h, resized_w = resize_idx.shape[:2]
330
+ resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])
331
+
332
+ # Global image goes first, so the order of patches in previous crops gets increased
333
+ pooling_idx = np.where(
334
+ pooling_idx >= 0,
335
+ pooling_idx + crop_patch_h*crop_patch_w,
336
+ -1
337
+ )
338
+ pooling_idx = np.concatenate([resize_idx, pooling_idx])
339
+ image_grid = [np.array([resized_h, resized_w, h, w])]
340
+
341
+ return (
342
+ np.stack(image_grid, 0),
343
+ batch_pixels_to_patches(crop_arr, image_patch_size),
344
+ pooling_idx
345
+ )
346
+
347
+
348
+ class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
349
+ max_crops: Optional[int]
350
+ overlap_margins: Optional[list[int]]
351
+ crop_mode: Optional[str]
352
+ patch_size: Optional[int]
353
+ pooling_size: Optional[list[int]]
354
+
355
+
356
+ class MolmoAct2ImageProcessor(BaseImageProcessor):
357
+ r"""
358
+ Constructs a MolmoAct2 image processor that preprocesses images for the model.
359
+
360
+ Args:
361
+ size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
362
+ Size of the image after resizing.
363
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
364
+ Resampling filter to use when resizing the image.
365
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
366
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
367
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
368
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
369
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
370
+ Whether to convert the image to RGB.
371
+ max_crops (`int`, *optional*, defaults to `8`):
372
+ Maximum number of crops to use per image.
373
+ overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
374
+ Overlap margins to use.
375
+ patch_size (`int`, *optional*, defaults to 14):
376
+ The spatial patch size of the vision encoder.
377
+ pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
378
+ The pooling size of the vision adapter.
379
+ """
380
+
381
+ model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
382
+
383
+ def __init__(
384
+ self,
385
+ size: Optional[dict[str, int]] = None,
386
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
387
+ image_mean: Optional[Union[float, list[float]]] = None,
388
+ image_std: Optional[Union[float, list[float]]] = None,
389
+ do_convert_rgb: bool = True,
390
+ max_crops: int = 8,
391
+ overlap_margins: list[int] = [4, 4],
392
+ crop_mode: str = "overlap-and-resize-c2",
393
+ patch_size: int = 14,
394
+ pooling_size: list[int] = [2, 2],
395
+ **kwargs,
396
+ ) -> None:
397
+ super().__init__(**kwargs)
398
+ size = size if size is not None else {"height": 378, "width": 378}
399
+ size = get_size_dict(size, default_to_square=True)
400
+ self.size = size
401
+
402
+ self.resample = resample
403
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
404
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
405
+ self.do_convert_rgb = do_convert_rgb
406
+
407
+ self.max_crops = max_crops
408
+ self.overlap_margins = overlap_margins
409
+ self.crop_mode = crop_mode
410
+ self.patch_size = patch_size
411
+ self.pooling_size = pooling_size
412
+
413
+ def preprocess(
414
+ self,
415
+ images: ImageInput,
416
+ size: Optional[dict[str, int]] = None,
417
+ resample: Optional[PILImageResampling] = None,
418
+ image_mean: Optional[Union[float, list[float]]] = None,
419
+ image_std: Optional[Union[float, list[float]]] = None,
420
+ do_convert_rgb: Optional[bool] = None,
421
+ max_crops: Optional[int] = None,
422
+ overlap_margins: Optional[list[int]] = None,
423
+ crop_mode: Optional[str] = None,
424
+ patch_size: Optional[int] = None,
425
+ pooling_size: Optional[list[int]] = None,
426
+ return_tensors: Optional[Union[str, TensorType]] = None,
427
+ **kwargs,
428
+ ) -> BatchFeature:
429
+ """
430
+ Args:
431
+ images (`ImageInput`):
432
+ Image to preprocess.
433
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
434
+ Size of the image after resizing.
435
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
436
+ Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
437
+ has an effect if `do_resize` is set to `True`.
438
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
439
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
440
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
441
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
442
+ `True`.
443
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
444
+ Whether to convert the image to RGB.
445
+ max_crops (`int`, *optional*, defaults to `self.max_crops`):
446
+ Maximum number of crops to use per image.
447
+ overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
448
+ Overlap margins to use.
449
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
450
+ The spatial patch size of the vision encoder.
451
+ pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
452
+ The pooling size of the vision adapter.
453
+ return_tensors (`str` or `TensorType`, *optional*):
454
+ The type of tensors to return. Can be one of:
455
+ - Unset: Return a list of `np.ndarray`.
456
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
457
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
458
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
459
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
460
+
461
+ Returns:
462
+ A `BatchFeature` containing the following keys:
463
+ - `pixel_values`: The preprocessed images.
464
+ - `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
465
+ - `image_grids`: The image grids.
466
+ - `image_num_crops`: The number of crops for each image.
467
+ """
468
+ if size is not None:
469
+ if "height" not in size or "width" not in size:
470
+ raise ValueError("size must contain 'height' and 'width' keys.")
471
+ else:
472
+ size = {**self.size}
473
+
474
+ base_image_input_size = [size["height"], size["width"]]
475
+
476
+ resample = resample or self.resample
477
+ image_mean = image_mean or self.image_mean
478
+ image_std = image_std or self.image_std
479
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
480
+
481
+ max_crops = max_crops or self.max_crops
482
+ overlap_margins = overlap_margins or self.overlap_margins
483
+ crop_mode = crop_mode or self.crop_mode
484
+ patch_size = patch_size or self.patch_size
485
+ pooling_size = pooling_size or self.pooling_size
486
+
487
+ image_pooling_h, image_pooling_w = pooling_size
488
+
489
+ if images is not None:
490
+ images = self.fetch_images(images)
491
+ images = make_flat_list_of_images(images)
492
+
493
+ if images is not None and not valid_images(images):
494
+ raise ValueError(
495
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
496
+ "torch.Tensor, tf.Tensor or jax.ndarray."
497
+ )
498
+
499
+ if do_convert_rgb:
500
+ images = [convert_to_rgb(image) for image in images]
501
+
502
+ # All transformations expect numpy arrays.
503
+ images = [to_numpy_array(image) for image in images]
504
+
505
+ data = {}
506
+ if images is not None:
507
+ batch_grids = []
508
+ batch_crops = []
509
+ batch_pooled_patches_idx = []
510
+ batch_num_crops = []
511
+
512
+ for image in images:
513
+ image_grid, crops, pooled_idx = image_to_patches_and_grids(
514
+ image,
515
+ max_crops,
516
+ overlap_margins,
517
+ base_image_input_size,
518
+ resample,
519
+ image_mean,
520
+ image_std,
521
+ patch_size,
522
+ image_pooling_w,
523
+ image_pooling_h,
524
+ crop_mode,
525
+ )
526
+ batch_grids.append(image_grid)
527
+ batch_crops.append(crops)
528
+ batch_pooled_patches_idx.append(pooled_idx)
529
+ batch_num_crops.append(crops.shape[0])
530
+
531
+ pixel_values = np.concatenate(batch_crops, 0)
532
+ image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
533
+ image_grids = np.concatenate(batch_grids, 0)
534
+ image_num_crops = np.array(batch_num_crops)
535
+
536
+ data.update(
537
+ pixel_values=pixel_values,
538
+ image_token_pooling=image_token_pooling,
539
+ image_grids=image_grids,
540
+ image_num_crops=image_num_crops,
541
+ )
542
+
543
+ return BatchFeature(data, tensor_type=return_tensors)
544
+
545
+
546
+ MolmoAct2ImageProcessor.register_for_auto_class()
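For reference, here is a minimal usage sketch of the image processor defined above. It assumes the processor is loaded from this repository with `trust_remote_code=True`; the local file name and the `return_tensors` choice are illustrative, not part of the repo.

```python
# Hedged sketch: load the custom image processor and preprocess one RGB image.
# Calling the processor dispatches to the preprocess() method shown above.
from transformers import AutoImageProcessor
from PIL import Image

image_processor = AutoImageProcessor.from_pretrained(
    "allenai/MolmoAct2-LIBERO", trust_remote_code=True  # repo id from this model card
)
image = Image.open("frame.png").convert("RGB")  # illustrative local image

features = image_processor(images=image, return_tensors="np")
# Keys documented in preprocess() above:
# pixel_values, image_token_pooling, image_grids, image_num_crops
for key, value in features.items():
    print(key, value.shape)
```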
inference.py ADDED
@@ -0,0 +1,768 @@
1
+ """Inference utilities for MolmoAct2"""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Iterable, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ from torch.nn import functional as F
8
+ from transformers.cache_utils import Cache
9
+ from transformers.configuration_utils import PretrainedConfig
10
+
11
+
12
+ @dataclass
13
+ class _ActionFlowInputs:
14
+ trajectory: torch.Tensor
15
+ context: Any
16
+ modulations: Sequence[Any]
17
+ action_dim_is_pad: Optional[torch.Tensor]
18
+
19
+
20
+ @dataclass
21
+ class _ActionFlowCudaGraph:
22
+ key: Tuple[Any, ...]
23
+ graph: torch.cuda.CUDAGraph
24
+ static_inputs: _ActionFlowInputs
25
+ output: torch.Tensor
26
+
27
+
28
+ @dataclass
29
+ class _DepthDecodeCudaGraphLayerStage:
30
+ residual: torch.Tensor
31
+ query: torch.Tensor
32
+ key: torch.Tensor
33
+ value: torch.Tensor
34
+
35
+
36
+ @dataclass
37
+ class _DepthDecodeCudaGraphPostStage:
38
+ graph: torch.cuda.CUDAGraph
39
+ attn_context: torch.Tensor
40
+
41
+
42
+ @dataclass
43
+ class _DepthDecodeCudaGraph:
44
+ cache_key: Tuple[Any, ...]
45
+ pre_graph: torch.cuda.CUDAGraph
46
+ token_ids: torch.Tensor
47
+ cos: torch.Tensor
48
+ sin: torch.Tensor
49
+ positions: torch.Tensor
50
+ stages: Sequence[_DepthDecodeCudaGraphLayerStage]
51
+ post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
52
+ output: torch.Tensor
53
+
54
+
55
+ @dataclass
56
+ class _DepthDecodeCudaGraphSpec:
57
+ eligible: bool
58
+ cache_key_prefix: Tuple[Any, ...]
59
+ num_hidden_layers: int
60
+ head_dim: int
61
+ num_attention_heads: int
62
+
63
+
64
+ def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
65
+ if past_key_values is None:
66
+ return 0
67
+ seq_len = past_key_values.get_seq_length()
68
+ if torch.is_tensor(seq_len):
69
+ return int(seq_len.item())
70
+ return int(seq_len)
71
+
72
+
73
+ def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
74
+ if past_key_values is None:
75
+ return -1
76
+ max_len = past_key_values.get_max_cache_shape()
77
+ if torch.is_tensor(max_len):
78
+ return int(max_len.item())
79
+ return int(max_len)
80
+
81
+
82
+ def _iter_cache_key_values(
83
+ past_key_values: Cache,
84
+ ) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
85
+ layers = getattr(past_key_values, "layers", None)
86
+ if layers is not None:
87
+ for layer in layers:
88
+ yield getattr(layer, "keys", None), getattr(layer, "values", None)
89
+ return
90
+ for layer in past_key_values:
91
+ yield layer[0], layer[1]
92
+
93
+
94
+ class _DepthDecodeStaticLayerCache:
95
+ is_compileable = False
96
+ is_sliding = False
97
+
98
+ def __init__(self, max_cache_len: int) -> None:
99
+ self.max_cache_len = int(max_cache_len)
100
+ self.cumulative_length = 0
101
+ self.keys: Optional[torch.Tensor] = None
102
+ self.values: Optional[torch.Tensor] = None
103
+
104
+ def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
105
+ bsz, n_heads = key_states.shape[:2]
106
+ self.keys = torch.empty(
107
+ (bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
108
+ dtype=key_states.dtype,
109
+ device=key_states.device,
110
+ )
111
+ self.values = torch.empty(
112
+ (bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
113
+ dtype=value_states.dtype,
114
+ device=value_states.device,
115
+ )
116
+
117
+ def update(
118
+ self,
119
+ key_states: torch.Tensor,
120
+ value_states: torch.Tensor,
121
+ *args,
122
+ **kwargs,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ if self.keys is None:
125
+ self._allocate(key_states, value_states)
126
+ start = self.cumulative_length
127
+ end = start + key_states.shape[-2]
128
+ if end > self.max_cache_len:
129
+ raise RuntimeError(
130
+ f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
131
+ )
132
+ self.keys[:, :, start:end, :].copy_(key_states)
133
+ self.values[:, :, start:end, :].copy_(value_states)
134
+ self.cumulative_length = end
135
+ return self.keys[:, :, :end, :], self.values[:, :, :end, :]
136
+
137
+ def get_seq_length(self) -> int:
138
+ return self.cumulative_length
139
+
140
+ def get_max_cache_shape(self) -> int:
141
+ return -1
142
+
143
+ def reset(self) -> None:
144
+ self.cumulative_length = 0
145
+
146
+
147
+ class _DepthDecodeStaticCache(Cache):
148
+ def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
149
+ text_config = config.get_text_config(decoder=True)
150
+ super().__init__(
151
+ layers=[
152
+ _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
153
+ for _ in range(text_config.num_hidden_layers)
154
+ ]
155
+ )
156
+
157
+ def get_seq_length(self, layer_idx: int = 0) -> int:
158
+ return self.layers[layer_idx].get_seq_length()
159
+
160
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
161
+ return self.layers[layer_idx].get_max_cache_shape()
162
+
163
+ def reset(self) -> None:
164
+ for layer in self.layers:
165
+ layer.reset()
166
+
167
+
168
+ class ActionCudaGraphManager:
169
+ def __init__(self, model: Any) -> None:
170
+ self.model = model
171
+ self.enabled = True
172
+ self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None
173
+
174
+ def set_enabled(self, enabled: bool) -> None:
175
+ self.enabled = bool(enabled)
176
+
177
+ def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
178
+ action_model = self.model
179
+ if not self.enabled:
180
+ return False
181
+ if action_model.training or action_model._require_action_expert().training:
182
+ return False
183
+ if inputs.trajectory.device.type != "cuda":
184
+ return False
185
+
186
+ def all_on_cuda():
187
+ yield inputs.trajectory
188
+ for k, v in inputs.context.kv_contexts:
189
+ yield k
190
+ yield v
191
+ for t in (
192
+ inputs.context.cross_mask,
193
+ inputs.context.self_mask,
194
+ inputs.context.valid_action,
195
+ inputs.action_dim_is_pad,
196
+ ):
197
+ if t is not None:
198
+ yield t
199
+ if inputs.context.rope_cache is not None:
200
+ yield from inputs.context.rope_cache
201
+ for step in inputs.modulations:
202
+ yield step.conditioning
203
+ for block_modulation in step.block_modulations:
204
+ yield from block_modulation
205
+ yield from step.final_modulation
206
+
207
+ return all(t.device.type == "cuda" for t in all_on_cuda())
208
+
209
+ def run_action_flow(
210
+ self,
211
+ inputs: _ActionFlowInputs,
212
+ steps: int,
213
+ run_loop,
214
+ ) -> torch.Tensor:
215
+ key = _cuda_graph_key(inputs, steps)
216
+ cache = self.action_flow_graph
217
+ if cache is None or cache.key != key:
218
+ static_inputs = _clone_static_inputs(inputs)
219
+ graph, output = _capture_cuda_graph(
220
+ lambda: run_loop(static_inputs, steps),
221
+ inputs.trajectory.device,
222
+ after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
223
+ )
224
+ cache = _ActionFlowCudaGraph(
225
+ key=key,
226
+ graph=graph,
227
+ static_inputs=static_inputs,
228
+ output=output,
229
+ )
230
+ self.action_flow_graph = cache
231
+ else:
232
+ _copy_inputs_(cache.static_inputs, inputs)
233
+
234
+ cache.graph.replay()
235
+ return cache.output.clone()
236
+
237
+
238
+ class DepthDecodeCudaGraphManager:
239
+ def __init__(self, model: Any) -> None:
240
+ self.model = model
241
+ self.backbone = model.model
242
+ self.enabled = True
243
+ self.graph: Optional[_DepthDecodeCudaGraph] = None
244
+ self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
245
+
246
+ def set_enabled(self, enabled: bool) -> None:
247
+ self.enabled = bool(enabled)
248
+
249
+ def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
250
+ return _DepthDecodeStaticCache(
251
+ config=self.model.config.text_config,
252
+ max_cache_len=max_cache_len,
253
+ )
254
+
255
+ def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
256
+ static = self.graph_spec
257
+ if static is None:
258
+ cfg = self.backbone.transformer.config
259
+ rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
260
+ static = _DepthDecodeCudaGraphSpec(
261
+ eligible=(
262
+ not cfg.norm_after
263
+ and cfg.rope_scaling_layers is None
264
+ and getattr(rotary_emb, "rope_type", None) == "default"
265
+ and cfg._attn_implementation == "sdpa"
266
+ ),
267
+ cache_key_prefix=(
268
+ cfg.hidden_size,
269
+ cfg.num_attention_heads,
270
+ cfg.num_key_value_heads,
271
+ cfg.head_dim,
272
+ cfg.num_hidden_layers,
273
+ cfg.use_qk_norm,
274
+ cfg.qk_norm_type,
275
+ cfg._attn_implementation,
276
+ ),
277
+ num_hidden_layers=cfg.num_hidden_layers,
278
+ head_dim=cfg.head_dim,
279
+ num_attention_heads=cfg.num_attention_heads,
280
+ )
281
+ self.graph_spec = static
282
+ return static
283
+
284
+ def can_use(
285
+ self,
286
+ next_input_ids: torch.Tensor,
287
+ *,
288
+ past_key_values: Cache,
289
+ attention_bias: torch.Tensor,
290
+ ) -> bool:
291
+ if (
292
+ not self.enabled
293
+ or self.model.training
294
+ or self.backbone.transformer.training
295
+ ):
296
+ return False
297
+ if next_input_ids.device.type != "cuda":
298
+ return False
299
+ if (
300
+ next_input_ids.ndim != 2
301
+ or next_input_ids.shape[0] != 1
302
+ or next_input_ids.shape[1] != 1
303
+ ):
304
+ return False
305
+ if not isinstance(past_key_values, _DepthDecodeStaticCache):
306
+ return False
307
+ if (
308
+ not torch.is_tensor(attention_bias)
309
+ or attention_bias.device != next_input_ids.device
310
+ ):
311
+ return False
312
+ return self._depth_decode_spec().eligible
313
+
314
+ def _depth_decode_key(
315
+ self,
316
+ next_input_ids: torch.Tensor,
317
+ attention_bias: torch.Tensor,
318
+ ) -> Tuple[Any, ...]:
319
+ device = next_input_ids.device
320
+ return (
321
+ self._depth_decode_spec().cache_key_prefix,
322
+ device.type,
323
+ device.index,
324
+ self.model.lm_head.weight.dtype,
325
+ attention_bias.shape[-1],
326
+ )
327
+
328
+ def _select_depth_decode_rope(
329
+ self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
330
+ ) -> None:
331
+ emb = self.backbone.transformer.rotary_emb
332
+ cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
333
+ sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
334
+
335
+ def _depth_decode_pre_layer(
336
+ self,
337
+ layer_idx: int,
338
+ hidden_states: torch.Tensor,
339
+ cos: torch.Tensor,
340
+ sin: torch.Tensor,
341
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
342
+ block = self.backbone.transformer.blocks[layer_idx]
343
+ attention = block.self_attn
344
+ residual = hidden_states
345
+ hidden_states = block.attn_norm(hidden_states)
346
+
347
+ input_shape = hidden_states.shape[:-1]
348
+ hidden_shape = (*input_shape, -1, attention.head_dim)
349
+ qkv = attention.att_proj(hidden_states)
350
+ query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
351
+ value_states = value_states.view(hidden_shape)
352
+
353
+ apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
354
+ norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
355
+
356
+ if apply_qk_norm and not norm_after_view:
357
+ query_states = attention.q_norm(query_states)
358
+ key_states = attention.k_norm(key_states)
359
+
360
+ query_states = query_states.view(hidden_shape)
361
+ key_states = key_states.view(hidden_shape)
362
+
363
+ if norm_after_view:
364
+ query_states = attention.q_norm(query_states)
365
+ key_states = attention.k_norm(key_states)
366
+
367
+ query_states = query_states.transpose(1, 2)
368
+ key_states = key_states.transpose(1, 2)
369
+ value_states = value_states.transpose(1, 2)
370
+ query_states, key_states = _apply_rotary_pos_emb(
371
+ query_states, key_states, cos, sin
372
+ )
373
+ return residual, query_states, key_states, value_states
374
+
375
+ def _depth_decode_pre0(
376
+ self,
377
+ token_ids: torch.Tensor,
378
+ cos: torch.Tensor,
379
+ sin: torch.Tensor,
380
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
381
+ inputs_embeds = self.model._embed_base_tokens(token_ids)
382
+ return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
383
+
384
+ def _depth_decode_post_layer(
385
+ self,
386
+ layer_idx: int,
387
+ residual: torch.Tensor,
388
+ attn_context: torch.Tensor,
389
+ ) -> torch.Tensor:
390
+ block = self.backbone.transformer.blocks[layer_idx]
391
+ attention = block.self_attn
392
+ input_shape = residual.shape[:-1]
393
+ attn_output = attn_context.reshape(*input_shape, -1).contiguous()
394
+ attn_output = attention.attn_out(attn_output)
395
+ hidden_states = residual + block.dropout(attn_output)
396
+
397
+ residual = hidden_states
398
+ hidden_states = block.ff_norm(hidden_states)
399
+ hidden_states = block.mlp(hidden_states)
400
+ hidden_states = residual + block.dropout(hidden_states)
401
+ return hidden_states
402
+
403
+ def _depth_decode_post_and_pre_next(
404
+ self,
405
+ layer_idx: int,
406
+ residual: torch.Tensor,
407
+ attn_context: torch.Tensor,
408
+ cos: torch.Tensor,
409
+ sin: torch.Tensor,
410
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
411
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
412
+ return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
413
+
414
+ def _depth_decode_last_post(
415
+ self,
416
+ layer_idx: int,
417
+ residual: torch.Tensor,
418
+ attn_context: torch.Tensor,
419
+ ) -> torch.Tensor:
420
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
421
+ return self.backbone.transformer.ln_f(hidden_states)
422
+
423
+ def _build_depth_decode_graph(
424
+ self,
425
+ next_input_ids: torch.Tensor,
426
+ *,
427
+ past_length: int,
428
+ attention_bias: torch.Tensor,
429
+ ) -> _DepthDecodeCudaGraph:
430
+ text_config = self.backbone.transformer.config
431
+ device = next_input_ids.device
432
+ dtype = self.model.lm_head.weight.dtype
433
+ static = self._depth_decode_spec()
434
+ num_layers = static.num_hidden_layers
435
+ head_dim = static.head_dim
436
+ max_cache_len = int(attention_bias.shape[-1])
437
+ max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
438
+ self.backbone.transformer.prepare_rope_cache(
439
+ device=device, max_seq_len=max_rope_len
440
+ )
441
+
442
+ token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
443
+ cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
444
+ sin = torch.empty_like(cos)
445
+ positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
446
+ context_shape = (1, 1, static.num_attention_heads, head_dim)
447
+
448
+ token_ids.copy_(next_input_ids)
449
+ self._select_depth_decode_rope(cos, sin, past_length=past_length)
450
+
451
+ pre_graph, pre_output = _capture_cuda_graph(
452
+ lambda: self._depth_decode_pre0(token_ids, cos, sin),
453
+ device,
454
+ )
455
+ stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
456
+ post_graphs = []
457
+ for layer_idx in range(num_layers - 1):
458
+ stage = stages[-1]
459
+ attn_context = torch.empty(context_shape, device=device, dtype=dtype)
460
+ graph, output = _capture_cuda_graph(
461
+ lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
462
+ self._depth_decode_post_and_pre_next(
463
+ layer_idx,
464
+ stage.residual,
465
+ attn_context,
466
+ cos,
467
+ sin,
468
+ )
469
+ ),
470
+ device,
471
+ )
472
+ post_graphs.append(
473
+ _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
474
+ )
475
+ stages.append(_DepthDecodeCudaGraphLayerStage(*output))
476
+
477
+ last_stage = stages[-1]
478
+ last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
479
+ last_graph, last_output = _capture_cuda_graph(
480
+ lambda: self._depth_decode_last_post(
481
+ num_layers - 1,
482
+ last_stage.residual,
483
+ last_attn_context,
484
+ ),
485
+ device,
486
+ )
487
+ post_graphs.append(
488
+ _DepthDecodeCudaGraphPostStage(
489
+ graph=last_graph, attn_context=last_attn_context
490
+ )
491
+ )
492
+ return _DepthDecodeCudaGraph(
493
+ cache_key=self._depth_decode_key(next_input_ids, attention_bias),
494
+ pre_graph=pre_graph,
495
+ token_ids=token_ids,
496
+ cos=cos,
497
+ sin=sin,
498
+ positions=positions,
499
+ stages=tuple(stages),
500
+ post_graphs=tuple(post_graphs),
501
+ output=last_output,
502
+ )
503
+
504
+ def _get_depth_decode_graph(
505
+ self,
506
+ next_input_ids: torch.Tensor,
507
+ *,
508
+ past_length: int,
509
+ attention_bias: torch.Tensor,
510
+ ) -> _DepthDecodeCudaGraph:
511
+ key = self._depth_decode_key(next_input_ids, attention_bias)
512
+ decode_graph = self.graph
513
+ if decode_graph is None or decode_graph.cache_key != key:
514
+ decode_graph = self._build_depth_decode_graph(
515
+ next_input_ids,
516
+ past_length=past_length,
517
+ attention_bias=attention_bias,
518
+ )
519
+ self.graph = decode_graph
520
+ else:
521
+ decode_graph.token_ids.copy_(next_input_ids)
522
+ self._select_depth_decode_rope(
523
+ decode_graph.cos, decode_graph.sin, past_length=past_length
524
+ )
525
+ return decode_graph
526
+
527
+ def _run_depth_decode_attention_core(
528
+ self,
529
+ layer_idx: int,
530
+ stage: _DepthDecodeCudaGraphLayerStage,
531
+ *,
532
+ past_key_values: Cache,
533
+ attention_bias: torch.Tensor,
534
+ cache_position: torch.Tensor,
535
+ cos: torch.Tensor,
536
+ sin: torch.Tensor,
537
+ ) -> torch.Tensor:
538
+ attention = self.backbone.transformer.blocks[layer_idx].self_attn
539
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
540
+ key_states, value_states = past_key_values.update(
541
+ stage.key,
542
+ stage.value,
543
+ layer_idx,
544
+ cache_kwargs,
545
+ )
546
+ key_states = _repeat_kv(key_states, attention.num_key_value_groups)
547
+ value_states = _repeat_kv(value_states, attention.num_key_value_groups)
548
+ attn_output = F.scaled_dot_product_attention(
549
+ stage.query,
550
+ key_states,
551
+ value_states,
552
+ attn_mask=attention_bias,
553
+ dropout_p=0.0,
554
+ is_causal=False,
555
+ )
556
+ return attn_output.transpose(1, 2)
557
+
558
+ def run(
559
+ self,
560
+ next_input_ids: torch.Tensor,
561
+ *,
562
+ past_key_values: Cache,
563
+ attention_bias: torch.Tensor,
564
+ past_length: int,
565
+ ) -> Tuple[torch.Tensor, Cache]:
566
+ end = past_length + 1
567
+ decode_graph = self._get_depth_decode_graph(
568
+ next_input_ids,
569
+ past_length=past_length,
570
+ attention_bias=attention_bias,
571
+ )
572
+ cache_position = decode_graph.positions[past_length:end]
573
+ attention_bias_q = attention_bias[:, :, past_length:end, :end]
574
+
575
+ decode_graph.pre_graph.replay()
576
+
577
+ for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
578
+ attn_context = self._run_depth_decode_attention_core(
579
+ layer_idx,
580
+ decode_graph.stages[layer_idx],
581
+ past_key_values=past_key_values,
582
+ attention_bias=attention_bias_q,
583
+ cache_position=cache_position,
584
+ cos=decode_graph.cos,
585
+ sin=decode_graph.sin,
586
+ )
587
+ post_graph.attn_context.copy_(attn_context)
588
+ post_graph.graph.replay()
589
+
590
+ return decode_graph.output, past_key_values
591
+
592
+
593
+ def _cuda_graph_tensor_signature(
594
+ tensor: Optional[torch.Tensor],
595
+ ) -> Optional[Tuple[Any, ...]]:
596
+ if tensor is None:
597
+ return None
598
+ return (
599
+ tuple(tensor.shape),
600
+ tuple(tensor.stride()),
601
+ str(tensor.dtype),
602
+ str(tensor.device),
603
+ )
604
+
605
+
606
+ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
607
+ sig = _cuda_graph_tensor_signature
608
+ return (
609
+ tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
610
+ sig(context.cross_mask),
611
+ sig(context.self_mask),
612
+ sig(context.valid_action),
613
+ None
614
+ if context.rope_cache is None
615
+ else tuple(sig(t) for t in context.rope_cache),
616
+ )
617
+
618
+
619
+ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
620
+ sig = _cuda_graph_tensor_signature
621
+ return tuple(
622
+ (
623
+ sig(step.conditioning),
624
+ tuple(
625
+ tuple(sig(t) for t in block_modulation)
626
+ for block_modulation in step.block_modulations
627
+ ),
628
+ tuple(sig(t) for t in step.final_modulation),
629
+ )
630
+ for step in modulations
631
+ )
632
+
633
+
634
+ def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
635
+ sig = _cuda_graph_tensor_signature
636
+ return (
637
+ sig(inputs.trajectory),
638
+ _cuda_graph_context_signature(inputs.context),
639
+ _cuda_graph_modulation_signature(inputs.modulations),
640
+ sig(inputs.action_dim_is_pad),
641
+ int(steps),
642
+ )
643
+
644
+
645
+ def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
646
+ if tensor is None:
647
+ return None
648
+ static = torch.empty_strided(
649
+ tuple(tensor.shape),
650
+ tuple(tensor.stride()),
651
+ device=tensor.device,
652
+ dtype=tensor.dtype,
653
+ )
654
+ static.copy_(tensor)
655
+ return static
656
+
657
+
658
+ def _clone_static_context(context: Any) -> Any:
659
+ rope_cache = None
660
+ if context.rope_cache is not None:
661
+ rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
662
+ return context.__class__(
663
+ kv_contexts=tuple(
664
+ (_clone_static_tensor(k), _clone_static_tensor(v))
665
+ for k, v in context.kv_contexts
666
+ ),
667
+ cross_mask=_clone_static_tensor(context.cross_mask),
668
+ self_mask=_clone_static_tensor(context.self_mask),
669
+ valid_action=_clone_static_tensor(context.valid_action),
670
+ rope_cache=rope_cache,
671
+ )
672
+
673
+
674
+ def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
675
+ return tuple(
676
+ step.__class__(
677
+ conditioning=_clone_static_tensor(step.conditioning),
678
+ block_modulations=tuple(
679
+ tuple(_clone_static_tensor(t) for t in block_modulation)
680
+ for block_modulation in step.block_modulations
681
+ ),
682
+ final_modulation=tuple(
683
+ _clone_static_tensor(t) for t in step.final_modulation
684
+ ),
685
+ )
686
+ for step in modulations
687
+ )
688
+
689
+
690
+ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
691
+ return _ActionFlowInputs(
692
+ trajectory=_clone_static_tensor(inputs.trajectory),
693
+ context=_clone_static_context(inputs.context),
694
+ modulations=_clone_static_modulations(inputs.modulations),
695
+ action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
696
+ )
697
+
698
+
699
+ def _copy_context_(dst: Any, src: Any) -> None:
700
+ for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
701
+ dst_k.copy_(src_k)
702
+ dst_v.copy_(src_v)
703
+ if src.cross_mask is not None:
704
+ dst.cross_mask.copy_(src.cross_mask)
705
+ if src.self_mask is not None:
706
+ dst.self_mask.copy_(src.self_mask)
707
+ if src.valid_action is not None:
708
+ dst.valid_action.copy_(src.valid_action)
709
+ if src.rope_cache is not None:
710
+ for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
711
+ dst_tensor.copy_(src_tensor)
712
+
713
+
714
+ def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
715
+ dst.trajectory.copy_(src.trajectory)
716
+ _copy_context_(dst.context, src.context)
717
+ if src.action_dim_is_pad is not None:
718
+ dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
719
+
720
+
721
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
722
+ x1 = x[..., : x.shape[-1] // 2]
723
+ x2 = x[..., x.shape[-1] // 2 :]
724
+ return torch.cat((-x2, x1), dim=-1)
725
+
726
+
727
+ def _apply_rotary_pos_emb(
728
+ q: torch.Tensor,
729
+ k: torch.Tensor,
730
+ cos: torch.Tensor,
731
+ sin: torch.Tensor,
732
+ unsqueeze_dim: int = 1,
733
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
734
+ cos = cos.unsqueeze(unsqueeze_dim)
735
+ sin = sin.unsqueeze(unsqueeze_dim)
736
+ q_embed = (q * cos) + (_rotate_half(q) * sin)
737
+ k_embed = (k * cos) + (_rotate_half(k) * sin)
738
+ return q_embed, k_embed
739
+
740
+
741
+ def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
742
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
743
+ if n_rep == 1:
744
+ return hidden_states
745
+ hidden_states = hidden_states[:, :, None, :, :].expand(
746
+ batch, num_key_value_heads, n_rep, slen, head_dim
747
+ )
748
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
749
+
750
+
751
+ def _capture_cuda_graph(
752
+ fn,
753
+ device: torch.device,
754
+ *,
755
+ after_warmup=None,
756
+ ) -> Tuple[torch.cuda.CUDAGraph, Any]:
757
+ warmup_stream = torch.cuda.Stream(device=device)
758
+ warmup_stream.wait_stream(torch.cuda.current_stream(device))
759
+ with torch.cuda.stream(warmup_stream):
760
+ fn()
761
+ torch.cuda.current_stream(device).wait_stream(warmup_stream)
762
+ if after_warmup is not None:
763
+ after_warmup()
764
+
765
+ graph = torch.cuda.CUDAGraph()
766
+ with torch.cuda.graph(graph):
767
+ output = fn()
768
+ return graph, output
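The `_capture_cuda_graph` helper above follows PyTorch's standard warmup-then-capture recipe: run the function once on a side stream so lazy allocations happen outside the capture, then record it into a `torch.cuda.CUDAGraph` and replay it against fixed input/output buffers. As a standalone illustration (every name below is made up for the example; a CUDA device is required):

```python
import torch

device = torch.device("cuda")
static_x = torch.randn(1, 1024, device=device)      # fixed input buffer
weight = torch.randn(1024, 1024, device=device)

def step():
    return static_x @ weight                          # work to be captured

# Warm up on a side stream so allocator and kernel initialization stay out of the graph.
side_stream = torch.cuda.Stream(device=device)
side_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(side_stream):
    step()
torch.cuda.current_stream(device).wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = step()                               # output buffer fixed at capture time

static_x.copy_(torch.randn(1, 1024, device=device))  # refresh inputs in place ...
graph.replay()                                        # ... then replay the captured kernels
print(static_out.sum().item())
```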
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5aa56113a3c1c95f0c03b90bada5b7ac2babf59c752a3d6201b8362a530c6809
3
+ size 4919324120
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fd95aec9ad144d3f4eac42db88a000321e016beeae606ecfcb84fcbdb13ecf4
3
+ size 4844690992
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6c455434915e74a8d750bedd0fc4c23b5a8ce37954aaff9c3e3f22085a26890
3
+ size 4844691024
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec4de13f1384f7b206f6090ade912667df17bfdb5da66fab7a3edf5c18c48ded
3
+ size 4998106920
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2808d1df315627f809d54baf4108160b10daa5095d64d4319ed5cd24e94be851
3
+ size 2334605176
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_molmoact2.py ADDED
The diff for this file is too large to render. See raw diff
 
norm_stats.json ADDED
@@ -0,0 +1,238 @@
1
+ {
2
+ "format": "molmoact2_norm_stats.v1",
3
+ "norm_mode": "q01_q99",
4
+ "metadata_by_tag": {
5
+ "libero": {
6
+ "action_key": "action",
7
+ "state_key": "observation.state",
8
+ "camera_keys": [
9
+ "observation.images.image",
10
+ "observation.images.wrist_image"
11
+ ],
12
+ "normalize_gripper": false,
13
+ "action_horizon": 10,
14
+ "n_action_steps": 10,
15
+ "setup_type": "single franka robotic arm in libero",
16
+ "control_mode": "delta end-effector pose",
17
+ "action_stats": {
18
+ "min": [
19
+ -0.9375,
20
+ -0.9375,
21
+ -0.9375,
22
+ -0.2582142949104309,
23
+ -0.375,
24
+ -0.3675000071525574,
25
+ -1.0
26
+ ],
27
+ "max": [
28
+ 0.9375,
29
+ 0.9375,
30
+ 0.9375,
31
+ 0.3557142913341522,
32
+ 0.375,
33
+ 0.375,
34
+ 1.0
35
+ ],
36
+ "mean": [
37
+ 0.06278156570450202,
38
+ 0.08684081017968912,
39
+ -0.09037305936952836,
40
+ 0.0005407430783705141,
41
+ 0.0056433796450358715,
42
+ -0.005229098518603562,
43
+ -0.04964072167678376
44
+ ],
45
+ "std": [
46
+ 0.3355237114945633,
47
+ 0.3784469867268323,
48
+ 0.44472859911256607,
49
+ 0.03924354049229973,
50
+ 0.06339296407444922,
51
+ 0.07797027713976648,
52
+ 0.9987671529022402
53
+ ],
54
+ "count": [
55
+ 273465.0
56
+ ],
57
+ "q01": [
58
+ -0.6792031928846481,
59
+ -0.7736573115323259,
60
+ -0.8728073904104404,
61
+ -0.10277447185825356,
62
+ -0.15509810617083444,
63
+ -0.20289961475228455,
64
+ -1.0
65
+ ],
66
+ "q10": [
67
+ -0.328718721971874,
68
+ -0.3626162647358338,
69
+ -0.6610056625361599,
70
+ -0.03907064459203904,
71
+ -0.06428551162168497,
72
+ -0.07928202560631951,
73
+ -1.0
74
+ ],
75
+ "q50": [
76
+ 0.015333975787982875,
77
+ 0.006437010746251905,
78
+ -0.07265095199149316,
79
+ -1.701317418858285e-05,
80
+ 0.00021801956089207239,
81
+ -5.852172701796134e-05,
82
+ -0.12287333595187695
83
+ ],
84
+ "q90": [
85
+ 0.5238177265233007,
86
+ 0.671417970219526,
87
+ 0.5384412174699407,
88
+ 0.040331002487738146,
89
+ 0.08240652401791884,
90
+ 0.0690125677722944,
91
+ 0.9999141552827842
92
+ ],
93
+ "q99": [
94
+ 0.8536542808794264,
95
+ 0.8637811051429717,
96
+ 0.9363295547540081,
97
+ 0.13045695485814487,
98
+ 0.18015313802054606,
99
+ 0.24129727661704234,
100
+ 0.9999914155282784
101
+ ],
102
+ "names": [
103
+ "x",
104
+ "y",
105
+ "z",
106
+ "roll",
107
+ "pitch",
108
+ "yaw",
109
+ "gripper"
110
+ ],
111
+ "mask": [
112
+ true,
113
+ true,
114
+ true,
115
+ true,
116
+ true,
117
+ true,
118
+ false
119
+ ]
120
+ },
121
+ "state_stats": {
122
+ "min": [
123
+ -0.4828203022480011,
124
+ -0.3255046010017395,
125
+ 0.008128180168569088,
126
+ 0.35277295112609863,
127
+ -3.641430377960205,
128
+ -1.842738389968872,
129
+ -0.0013586411951109767,
130
+ -0.042040832340717316
131
+ ],
132
+ "max": [
133
+ 0.21031762659549713,
134
+ 0.39128610491752625,
135
+ 1.3660105466842651,
136
+ 3.6714255809783936,
137
+ 3.560650587081909,
138
+ 1.386339545249939,
139
+ 0.04233968257904053,
140
+ 0.0013633022317662835
141
+ ],
142
+ "mean": [
143
+ -0.04651878279191748,
144
+ 0.034409066787269356,
145
+ 0.7645525031210381,
146
+ 2.9722094975655056,
147
+ -0.22046978549041713,
148
+ -0.1255794031738752,
149
+ 0.026914253269017054,
150
+ -0.027190783616938205
151
+ ],
152
+ "std": [
153
+ 0.10494395508556839,
154
+ 0.1517661933220375,
155
+ 0.378516707505034,
156
+ 0.34427344187858827,
157
+ 0.9069468516043042,
158
+ 0.32539190149967406,
159
+ 0.01417590382231912,
160
+ 0.014058894296088888
161
+ ],
162
+ "count": [
163
+ 273465.0
164
+ ],
165
+ "q01": [
166
+ -0.31479429659059555,
167
+ -0.26691552643710226,
168
+ 0.5194626050191016,
169
+ 2.159994551314992,
170
+ -1.801294177865994,
171
+ -0.8949778881389838,
172
+ 0.003382730811955442,
173
+ -0.04008920533069468
174
+ ],
175
+ "q10": [
176
+ -0.18409729127502492,
177
+ -0.158759498072202,
178
+ 0.5694822295083012,
179
+ 2.501970046458546,
180
+ -1.1889107640062022,
181
+ -0.5297043790093273,
182
+ 0.007573322430226042,
183
+ -0.039827946964434036
184
+ ],
185
+ "q50": [
186
+ -0.02822545357081922,
187
+ 0.029718887641213443,
188
+ 0.7185643731428462,
189
+ 3.0915725099012166,
190
+ -0.12491069931831773,
191
+ -0.08338984738533357,
192
+ 0.030648370056451133,
193
+ -0.031519123023466586
194
+ ],
195
+ "q90": [
196
+ 0.06725052913150302,
197
+ 0.23387160335018267,
198
+ 0.9599947530498419,
199
+ 3.1743361507512997,
200
+ 0.5456820212337484,
201
+ 0.20414514594693875,
202
+ 0.03985537019679712,
203
+ -0.008040434619037518
204
+ ],
205
+ "q99": [
206
+ 0.1222615490116252,
207
+ 0.3140223876046953,
208
+ 1.042961724319958,
209
+ 3.277638017923068,
210
+ 1.724488202195691,
211
+ 0.5659922739094448,
212
+ 0.04009682017699841,
213
+ -0.003493522538066522
214
+ ],
215
+ "names": [
216
+ "x",
217
+ "y",
218
+ "z",
219
+ "rx",
220
+ "ry",
221
+ "rz",
222
+ "rw",
223
+ "gripper"
224
+ ],
225
+ "mask": [
226
+ true,
227
+ true,
228
+ true,
229
+ true,
230
+ true,
231
+ true,
232
+ true,
233
+ false
234
+ ]
235
+ }
236
+ }
237
+ }
238
+ }
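For orientation, the sketch below shows one common way `q01`/`q99` quantile statistics like these are applied: masked action dimensions are mapped into `[-1, 1]` while unmasked dimensions (here the gripper) pass through unchanged. This is an illustration only; the exact normalization MolmoAct2 applies is defined in the repository's inference code and may differ.

```python
import json
import numpy as np

with open("norm_stats.json") as f:
    stats = json.load(f)["metadata_by_tag"]["libero"]["action_stats"]

q01 = np.array(stats["q01"])
q99 = np.array(stats["q99"])
mask = np.array(stats["mask"])  # gripper dimension is excluded from normalization

def normalize(action: np.ndarray) -> np.ndarray:
    scaled = 2.0 * (action - q01) / (q99 - q01 + 1e-8) - 1.0
    return np.where(mask, np.clip(scaled, -1.0, 1.0), action)

def unnormalize(action: np.ndarray) -> np.ndarray:
    raw = (action + 1.0) / 2.0 * (q99 - q01) + q01
    return np.where(mask, raw, action)
```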
processing_molmoact2.py ADDED
@@ -0,0 +1,418 @@
1
+ """
2
+ Processor class for MolmoAct2.
3
+ """
4
+ from typing import Optional, Union
5
+ import dataclasses
6
+
7
+ import numpy as np
8
+
9
+ from transformers.image_utils import ImageInput
10
+ from transformers.video_utils import VideoInput
11
+ from transformers.processing_utils import (
12
+ Unpack,
13
+ ProcessingKwargs,
14
+ ProcessorMixin,
15
+ )
16
+ from transformers.feature_extraction_utils import BatchFeature
17
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
18
+ from transformers.utils import logging
19
+
20
+ from transformers import AutoTokenizer
21
+ from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
22
+ from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ # Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
29
+ IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
30
+ IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
31
+ IM_START_TOKEN = f"<im_start>"
32
+ LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
33
+ FRAME_START_TOKEN = f"<frame_start>"
34
+ IM_END_TOKEN = f"<im_end>"
35
+ FRAME_END_TOKEN = f"<frame_end>"
36
+ IM_COL_TOKEN = f"<im_col>"
37
+ IMAGE_PROMPT = "<|image|>"
38
+ VIDEO_PROMPT = "<|video|>"
39
+
40
+ IMAGE_TOKENS = [
41
+ IMAGE_PATCH_TOKEN,
42
+ IM_COL_TOKEN,
43
+ IM_START_TOKEN,
44
+ LOW_RES_IMAGE_START_TOKEN,
45
+ FRAME_START_TOKEN,
46
+ IM_END_TOKEN,
47
+ FRAME_END_TOKEN,
48
+ IMAGE_LOW_RES_TOKEN,
49
+ ]
50
+
51
+
52
+ class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
53
+ """MolmoAct2 processor kwargs"""
54
+ images_kwargs: MolmoAct2ImagesKwargs
55
+ videos_kwargs: MolmoAct2VideoProcessorKwargs
56
+ _defaults = {
57
+ "text_kwargs": {
58
+ "padding": False,
59
+ "return_mm_token_type_ids": True,
60
+ },
61
+ "videos_kwargs": {"return_metadata": True},
62
+ }
63
+
64
+
65
+ class MolmoAct2Processor(ProcessorMixin):
66
+ attributes = ["image_processor", "video_processor", "tokenizer"]
67
+ optional_attributes = [
68
+ "chat_template",
69
+ "time_mode",
70
+ "image_use_col_tokens",
71
+ "use_single_crop_col_tokens",
72
+ "use_single_crop_start_token",
73
+ "video_use_col_tokens",
74
+ "use_frame_special_tokens",
75
+ ]
76
+ image_processor_class = "AutoImageProcessor"
77
+ video_processor_class = "AutoVideoProcessor"
78
+ tokenizer_class = "AutoTokenizer"
79
+
80
+ def __init__(
81
+ self,
82
+ image_processor: MolmoAct2ImageProcessor = None,
83
+ video_processor: MolmoAct2VideoProcessor = None,
84
+ tokenizer: AutoTokenizer = None,
85
+ chat_template: Optional[str] = None,
86
+ image_use_col_tokens: Optional[bool] = True,
87
+ use_single_crop_col_tokens: Optional[bool] = None,
88
+ use_single_crop_start_token: Optional[bool] = True,
89
+ video_use_col_tokens: Optional[bool] = False,
90
+ use_frame_special_tokens: Optional[bool] = True,
91
+ **kwargs
92
+ ) -> None:
93
+ super().__init__(
94
+ image_processor,
95
+ video_processor,
96
+ tokenizer,
97
+ chat_template=chat_template,
98
+ )
99
+ self.image_use_col_tokens = image_use_col_tokens
100
+ self.use_single_crop_col_tokens = use_single_crop_col_tokens
101
+ self.use_single_crop_start_token = use_single_crop_start_token
102
+ self.video_use_col_tokens = video_use_col_tokens
103
+ self.use_frame_special_tokens = use_frame_special_tokens
104
+
105
+ self.image_placeholder_token = IMAGE_PROMPT
106
+ self.video_placeholder_token = VIDEO_PROMPT
107
+ self.image_token_ids = [
108
+ tokenizer.convert_tokens_to_ids(token)
109
+ for token in IMAGE_TOKENS
110
+ ]
111
+
112
+ def get_image_tokens(self, image_grid: np.ndarray):
113
+ resized_h, resized_w, height, width = image_grid
114
+ if int(height) == 0 or int(width) == 0:
115
+ per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
116
+ use_single_crop_col_tokens = (
117
+ self.image_use_col_tokens
118
+ if self.use_single_crop_col_tokens is None
119
+ else self.use_single_crop_col_tokens
120
+ )
121
+ if use_single_crop_col_tokens:
122
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
123
+ joint = [
124
+ [IM_START_TOKEN],
125
+ np.tile(per_row, [resized_h]),
126
+ [IM_END_TOKEN],
127
+ ]
128
+ return np.concatenate(joint)
129
+ per_row = np.full(width, IMAGE_PATCH_TOKEN)
130
+ if self.image_use_col_tokens:
131
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
132
+ joint = [
133
+ [IM_START_TOKEN],
134
+ np.tile(per_row, [height]),
135
+ [IM_END_TOKEN],
136
+ ]
137
+ per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
138
+ use_single_crop_col_tokens = (
139
+ self.image_use_col_tokens
140
+ if self.use_single_crop_col_tokens is None
141
+ else self.use_single_crop_col_tokens
142
+ )
143
+ image_start_token = (
144
+ LOW_RES_IMAGE_START_TOKEN
145
+ if self.use_single_crop_start_token
146
+ else IM_START_TOKEN
147
+ )
148
+ if use_single_crop_col_tokens:
149
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
150
+ joint = [
151
+ [image_start_token],
152
+ np.tile(per_row, [resized_h]),
153
+ [IM_END_TOKEN],
154
+ ] + joint
155
+
156
+ return np.concatenate(joint)
157
+
158
+ def get_video_string(
159
+ self,
160
+ video_grid: np.ndarray,
161
+ timestamps: np.ndarray,
162
+ ):
163
+ if self.use_frame_special_tokens:
164
+ start_token_id = FRAME_START_TOKEN
165
+ end_token_id = FRAME_END_TOKEN
166
+ else:
167
+ start_token_id = IM_START_TOKEN
168
+ end_token_id = IM_END_TOKEN
169
+
170
+ num_frames, h, w = video_grid
171
+ video_string: str = ""
172
+ for frame_idx, frame_time in enumerate(timestamps):
173
+ # `per-frame-compact` time mode
174
+ prev_space = " " if frame_idx > 0 else ""
175
+ frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
176
+
177
+ video_string += frame_prefix
178
+ per_row = np.full(w, IMAGE_PATCH_TOKEN)
179
+ if self.video_use_col_tokens:
180
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
181
+ extra_tokens = np.tile(per_row, [h])
182
+ video_tokens = [
183
+ [start_token_id],
184
+ extra_tokens,
185
+ [end_token_id],
186
+ ]
187
+ video_string += "".join(np.concatenate(video_tokens, 0))
188
+
189
+ return video_string
190
+
191
+ def insert_bos(
192
+ self,
193
+ input_ids: np.ndarray,
194
+ attention_mask: np.ndarray,
195
+ bos_token_id: int,
196
+ pad_token_id: int,
197
+ ):
198
+ """
199
+ Args:
200
+ input_ids: [B, S] array with left padding
201
+ attention_mask: [B, S] array (0 for pad, 1 for valid)
202
+ bos_token_id: int
203
+ pad_token_id: int
204
+ Returns:
205
+ input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
206
+ attention_mask_out: same shape as input_ids_out
207
+ """
208
+
209
+ need_to_expand = len(input_ids.shape) == 1
210
+ if need_to_expand:
211
+ input_ids = input_ids[None, :]
212
+ attention_mask = attention_mask[None, :]
213
+
214
+ B, S = input_ids.shape
215
+
216
+ # Handle zero-length sequence
217
+ if S == 0:
218
+ new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
219
+ new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
220
+ if need_to_expand:
221
+ new_input_ids = new_input_ids[0]
222
+ new_attention_mask = new_attention_mask[0]
223
+ return new_input_ids, new_attention_mask
224
+
225
+ first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
226
+ bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
227
+
228
+ if bos_already_present:
229
+ if need_to_expand:
230
+ input_ids = input_ids[0]
231
+ attention_mask = attention_mask[0]
232
+ return input_ids, attention_mask
233
+ else:
234
+ new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
235
+ new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
236
+
237
+ src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
238
+ valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
239
+ tgt_idx = src_idx + 1  # shift right
240
+ batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
241
+
242
+ # flatten valid_positions
243
+ flat_vals = input_ids[valid_mask]
244
+ flat_batch = batch_idx[valid_mask]
245
+ flat_tgt = tgt_idx[valid_mask]
246
+
247
+ new_input_ids[flat_batch, flat_tgt] = flat_vals
248
+ new_attention_mask[flat_batch, flat_tgt] = 1
249
+
250
+ insert_pos = first_valid_index
251
+ new_input_ids[np.arange(B), insert_pos] = bos_token_id
252
+ new_attention_mask[np.arange(B), insert_pos] = 1
253
+
254
+ if need_to_expand:
255
+ new_input_ids = new_input_ids[0]
256
+ new_attention_mask = new_attention_mask[0]
257
+
258
+ return new_input_ids, new_attention_mask
259
+
260
+ def __call__(
261
+ self,
262
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
263
+ images: ImageInput = None,
264
+ videos: VideoInput = None,
265
+ **kwargs: Unpack[MolmoAct2ProcessorKwargs],
266
+ ) -> BatchFeature:
267
+ """
268
+
269
+ Args:
270
+ text (`str`, `list[str]`, `list[list[str]]`):
271
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
272
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
273
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
274
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
275
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
276
+ tensor. Both channels-first and channels-last formats are supported.
277
+ videos (`dict[str, Any]` or `list[dict[str, Any]]`):
278
+ The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
279
+ - `"frames"`: `np.ndarray` of shape (T, H, W, 3)
280
+ - `"timestamps"`: `np.ndarray` of shape (T,)
281
+ - `"sampled_fps"`: `float` (optional)
282
+ - `"sampling_augmentation"`: `str` (optional)
283
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
284
+ If set, will return tensors of a particular framework. Acceptable values are:
285
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
286
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
287
+ - `'np'`: Return NumPy `np.ndarray` objects.
288
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
289
+
290
+ Returns:
291
+ `BatchFeature`: A [`BatchFeature`] with the following fields:
292
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
293
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
294
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
295
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
296
+ - **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
297
+ Returned when `images` is not `None`.
298
+ - **image_grids** -- Grids of images. Returned when `images` is not `None`.
299
+ - **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
300
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
301
+ - **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
302
+ Returned when `videos` is not `None`.
303
+ - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
304
+ """
305
+
306
+ output_kwargs = self._merge_kwargs(
307
+ MolmoAct2ProcessorKwargs,
308
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
309
+ **kwargs,
310
+ )
311
+
312
+ if images is not None:
313
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
314
+ image_grids = image_inputs["image_grids"]
315
+ else:
316
+ image_inputs = {}
317
+ image_grids = None
318
+
319
+ if videos is not None:
320
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
321
+ video_grids = videos_inputs["video_grids"]
322
+ # If user has not requested video metadata, pop it
323
+ if "return_metadata" not in kwargs:
324
+ video_metadata = videos_inputs.pop("video_metadata")
325
+ else:
326
+ video_metadata = videos_inputs["video_metadata"]
327
+ else:
328
+ videos_inputs = {}
329
+ video_grids = None
330
+
331
+ if not isinstance(text, list):
332
+ text = [text]
333
+
334
+ text = text.copy() # below lines change text in-place
335
+
336
+ if image_grids is not None:
337
+ index = 0
338
+ for i in range(len(text)):
339
+ num_images = text[i].count(self.image_placeholder_token)
340
+ image_grids_i = image_grids[index:index+num_images]
341
+ for image_grid in image_grids_i:
342
+ image_tokens = self.get_image_tokens(image_grid)
343
+ image_string = "".join(image_tokens)
344
+ text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
345
+ index += num_images
346
+
347
+ if video_grids is not None:
348
+ index = 0
349
+ for i in range(len(text)):
350
+ num_videos = text[i].count(self.video_placeholder_token)
351
+ assert num_videos in {0, 1}, "At most one video is supported for now"
352
+ video_grids_i = video_grids[index:index+num_videos]
353
+ metadata_i = video_metadata[index:index+num_videos]
354
+ for video_grid, metadata in zip(video_grids_i, metadata_i):
355
+ video_string = self.get_video_string(
356
+ video_grid,
357
+ metadata.timestamps,
358
+ )
359
+ text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
360
+ index += num_videos
361
+
362
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
363
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
364
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
365
+
366
+ input_ids = text_inputs["input_ids"]
367
+ attention_mask = text_inputs["attention_mask"]
368
+
369
+ input_ids = np.array(input_ids)
370
+ attention_mask = np.array(attention_mask)
371
+
372
+ bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
373
+ input_ids, attention_mask = self.insert_bos(
374
+ input_ids, attention_mask, bos, self.tokenizer.pad_token_id
375
+ )
376
+
377
+ if return_mm_token_type_ids:
378
+ image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
379
+ token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
380
+ text_inputs["token_type_ids"] = token_type_ids.tolist()
381
+
382
+ text_inputs["input_ids"] = input_ids.tolist()
383
+ text_inputs["attention_mask"] = attention_mask.tolist()
384
+
385
+ return BatchFeature(
386
+ data={**text_inputs, **image_inputs, **videos_inputs},
387
+ tensor_type=return_tensors,
388
+ )
389
+
390
+ def post_process_image_text_to_text(
391
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
392
+ ):
393
+ """
394
+ Post-process the output of the model to decode the text.
395
+
396
+ Args:
397
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
398
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
399
+ or `(sequence_length,)`.
400
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
401
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
402
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
403
+ Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
404
+ **kwargs:
405
+ Additional arguments to be passed to the tokenizer's `batch_decode method`.
406
+
407
+ Returns:
408
+ `list[str]`: The decoded text.
409
+ """
410
+ return self.tokenizer.batch_decode(
411
+ generated_outputs,
412
+ skip_special_tokens=skip_special_tokens,
413
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
414
+ **kwargs,
415
+ )
416
+
417
+
418
+ MolmoAct2Processor.register_for_auto_class()
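A minimal end-to-end sketch of the processor defined above, assuming it is loaded from this repository with `trust_remote_code=True`; the prompt text and image path are illustrative.

```python
from transformers import AutoProcessor
from PIL import Image

processor = AutoProcessor.from_pretrained(
    "allenai/MolmoAct2-LIBERO", trust_remote_code=True
)
image = Image.open("frame.png").convert("RGB")  # illustrative local image

# The "<|image|>" placeholder is expanded into the image special tokens
# (<im_start>, <im_patch>, <im_col>, <im_end>, ...) by __call__ above.
inputs = processor(
    text="<|image|> What should the robot do next?",
    images=[image],
    return_tensors="pt",
)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
```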
processor_config.json ADDED
@@ -0,0 +1,85 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
4
+ },
5
+ "image_processor": {
6
+ "auto_map": {
7
+ "AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
8
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
9
+ },
10
+ "crop_mode": "resize",
11
+ "do_convert_rgb": true,
12
+ "image_mean": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "image_processor_type": "MolmoAct2ImageProcessor",
18
+ "image_std": [
19
+ 0.5,
20
+ 0.5,
21
+ 0.5
22
+ ],
23
+ "max_crops": 8,
24
+ "overlap_margins": [
25
+ 4,
26
+ 4
27
+ ],
28
+ "patch_size": 14,
29
+ "pooling_size": [
30
+ 2,
31
+ 2
32
+ ],
33
+ "resample": 2,
34
+ "size": {
35
+ "height": 378,
36
+ "width": 378
37
+ }
38
+ },
39
+ "image_use_col_tokens": true,
40
+ "processor_class": "MolmoAct2Processor",
41
+ "use_frame_special_tokens": true,
42
+ "use_single_crop_col_tokens": false,
43
+ "use_single_crop_start_token": true,
44
+ "video_processor": {
45
+ "auto_map": {
46
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
47
+ "AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
48
+ },
49
+ "data_format": "channels_first",
50
+ "default_to_square": true,
51
+ "do_convert_rgb": true,
52
+ "do_normalize": true,
53
+ "do_rescale": true,
54
+ "do_resize": true,
55
+ "do_sample_frames": true,
56
+ "frame_sample_mode": "uniform_last_frame",
57
+ "image_mean": [
58
+ 0.5,
59
+ 0.5,
60
+ 0.5
61
+ ],
62
+ "image_std": [
63
+ 0.5,
64
+ 0.5,
65
+ 0.5
66
+ ],
67
+ "max_fps": 2.0,
68
+ "num_frames": 8,
69
+ "patch_size": 14,
70
+ "pooling_size": [
71
+ 3,
72
+ 3
73
+ ],
74
+ "resample": 2,
75
+ "rescale_factor": 0.00392156862745098,
76
+ "return_metadata": false,
77
+ "sampling_fps": 2,
78
+ "size": {
79
+ "height": 378,
80
+ "width": 378
81
+ },
82
+ "video_processor_type": "MolmoAct2VideoProcessor"
83
+ },
84
+ "video_use_col_tokens": false
85
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5395aefc9b1b7f0385d8c86a2f1775e5af81bdfbf9f2d97827ea37921d9f862
3
+ size 11983605
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
5
+ },
6
+ "backend": "tokenizers",
7
+ "bos_token": "<|im_end|>",
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "extra_special_tokens": [
12
+ "<im_start>",
13
+ "<im_end>",
14
+ "<im_patch>",
15
+ "<im_col>",
16
+ "<low_res_im_start>",
17
+ "<|image|>",
18
+ "<im_low>",
19
+ "<frame_start>",
20
+ "<frame_end>",
21
+ "<|video|>",
22
+ "<|points|>",
23
+ "<|token_index|>",
24
+ "<|vit_index|>",
25
+ "<|vit_loc|>"
26
+ ],
27
+ "is_local": false,
28
+ "model_max_length": 1010000,
29
+ "pad_token": "<|endoftext|>",
30
+ "processor_class": "MolmoAct2Processor",
31
+ "split_special_tokens": false,
32
+ "tokenizer_class": "Qwen2Tokenizer",
33
+ "unk_token": null
34
+ }
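
Since `tokenizer_class` above is the standard `Qwen2Tokenizer`, the tokenizer can be loaded on its own; a small hedged check of the extra special tokens might look like this (the printed ids are not asserted here).

```python
from transformers import AutoTokenizer

# Hedged sketch: load the tokenizer configured above (assumes access to this repo).
tok = AutoTokenizer.from_pretrained("allenai/MolmoAct2-LIBERO")

# A few of the extra special tokens used for image/video structure.
for token in ["<im_start>", "<im_patch>", "<im_col>", "<frame_start>", "<frame_end>"]:
    print(token, tok.convert_tokens_to_ids(token))

print(tok.eos_token, tok.pad_token)  # "<|im_end|>", "<|endoftext|>"
```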
video_processing_molmoact2.py ADDED
@@ -0,0 +1,969 @@
1
+ """Video processor class for MolmoAct2"""
2
+ from functools import partial
3
+ import os
4
+ import warnings
5
+ from contextlib import redirect_stdout
6
+ from io import BytesIO
7
+ from urllib.parse import urlparse
8
+ from typing import Optional, Union, Callable
9
+
10
+ import numpy as np
11
+ import requests
12
+ import einops
13
+ import torch
14
+ import torchvision.transforms
15
+
16
+ from transformers.image_utils import (
17
+ IMAGENET_STANDARD_MEAN,
18
+ IMAGENET_STANDARD_STD,
19
+ ImageInput,
20
+ PILImageResampling,
21
+ SizeDict,
22
+ validate_kwargs,
23
+ )
24
+ from transformers.video_utils import (
25
+ VideoInput,
26
+ is_valid_video,
27
+ make_batched_videos,
28
+ make_batched_metadata,
29
+ VideoMetadata,
30
+ )
31
+ from transformers.processing_utils import Unpack, VideosKwargs
32
+ from transformers.video_processing_utils import BaseVideoProcessor
33
+ from transformers.utils import logging
34
+ from transformers.feature_extraction_utils import BatchFeature
35
+ from transformers.utils import (
36
+ is_av_available,
37
+ is_decord_available,
38
+ is_torchcodec_available,
39
+ is_yt_dlp_available,
40
+ TensorType,
41
+ logging,
42
+ to_numpy,
43
+ )
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ MAX_VIDEO_FPS = 8
49
+
50
+
51
+ def normalize_image(
52
+ image: np.ndarray,
53
+ image_mean: list[float],
54
+ image_std: list[float],
55
+ ) -> np.ndarray:
56
+ if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
57
+ return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
58
+ image -= np.array(image_mean, dtype=np.float32)[None, None, :]
59
+ image /= np.array(image_std, dtype=np.float32)[None, None, :]
60
+ return image
61
+
62
+
63
+ def resize_image(
64
+ image: np.ndarray,
65
+ desired_output_size: list[int],
66
+ resample: PILImageResampling,
67
+ ) -> np.ndarray:
68
+ if len(image.shape) == 3:
69
+ is_video = False
70
+ image = torch.permute(torch.from_numpy(image), [2, 0, 1])
71
+ else:
72
+ is_video = True
73
+ image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
74
+ dtype = image.dtype
75
+ if torch.is_floating_point(image):
76
+ in_min = 0.0
77
+ in_max = 1.0
78
+ resized = torchvision.transforms.Resize(
79
+ desired_output_size,
80
+ resample,
81
+ antialias=False,
82
+ )(image)
83
+ resized = torch.clip(resized, 0.0, 1.0).to(dtype)
84
+ else:
85
+ assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
86
+ in_min = 0.0
87
+ in_max = 255.0
88
+ resized = torchvision.transforms.Resize(
89
+ desired_output_size,
90
+ resample,
91
+ antialias=False,
92
+ )(image)
93
+ resized = torch.clip(resized, 0, 255).to(dtype)
94
+
95
+ resized = resized.to(torch.float32)
96
+ resized = (resized - in_min) / (in_max - in_min)
97
+
98
+ if is_video:
99
+ resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
100
+ else:
101
+ resized = torch.permute(resized, [1, 2, 0]).numpy()
102
+
103
+ return resized
104
+
105
+
106
+ def build_resized_image(
107
+ image: np.ndarray,
108
+ base_image_input_size: list[int],
109
+ resample: PILImageResampling,
110
+ image_mean: list[float],
111
+ image_std: list[float],
112
+ image_patch_size: int,
113
+ ) -> tuple[np.ndarray, np.ndarray]:
114
+ resized = resize_image(
115
+ image, base_image_input_size, resample,
116
+ )
117
+ resized = normalize_image(resized, image_mean, image_std)
118
+ if len(resized.shape) == 3:
119
+ resized = np.expand_dims(resized, 0)
120
+ crop_patch_w = base_image_input_size[1] // image_patch_size
121
+ crop_patch_h = base_image_input_size[0] // image_patch_size
122
+ resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
123
+ return resized, resize_idx
124
+
125
+
126
+ def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
127
+ """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
128
+ if len(array.shape) == 3:
129
+ n_crops, h, w = array.shape
130
+ h_patches = h//patch_size
131
+ w_patches = w//patch_size
132
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
133
+ array = np.transpose(array, [0, 1, 3, 2, 4])
134
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
135
+ return array
136
+ else:
137
+ n_crops, h, w, c = array.shape
138
+ h_patches = h//patch_size
139
+ w_patches = w//patch_size
140
+ array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
141
+ array = np.transpose(array, [0, 1, 3, 2, 4, 5])
142
+ array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
143
+ return array
144
+
145
+
146
+ def arange_for_pooling(
147
+ idx_arr: np.ndarray,
148
+ pool_h: int,
149
+ pool_w: int,
150
+ ) -> np.ndarray:
151
+ h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
152
+ w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
153
+ idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
154
+ mode='constant',constant_values=-1)
155
+ return einops.rearrange(
156
+ idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
157
+
158
+
159
+ def image_to_patches_and_grids(
160
+ image: ImageInput,
161
+ base_image_input_size: list[int],
162
+ resample: PILImageResampling,
163
+ image_mean: list[float],
164
+ image_std: list[float],
165
+ image_patch_size: int,
166
+ image_pooling_w: int,
167
+ image_pooling_h: int,
168
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
169
+ """
170
+ :return image_grids, the shape of each image after pooling
171
+ :return crops, the image crops to process with the ViT
172
+ :return pooled_patch_idx, for each patch_id token in `image_tokens`, the indices of the
173
+ patches in `crops` to pool for that token, masked with -1
174
+ """
175
+ if isinstance(base_image_input_size, int):
176
+ base_image_input_size = (base_image_input_size, base_image_input_size)
177
+
178
+ pooling_w = image_pooling_w
179
+ pooling_h = image_pooling_h
180
+
181
+ resized, resize_idx = build_resized_image(
182
+ image,
183
+ base_image_input_size,
184
+ resample,
185
+ image_mean,
186
+ image_std,
187
+ image_patch_size,
188
+ )
189
+ pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
190
+ h, w = pooling_idx.shape[:2]
191
+ pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
192
+ image_grid = [h, w]
193
+ return (
194
+ image_grid,
195
+ batch_pixels_to_patches(resized, image_patch_size),
196
+ pooling_idx,
197
+ )
198
+
199
+
200
+ def get_candidate_target_fps(
201
+ video_fps: Union[int, float],
202
+ sampling_fps: Union[int, float],
203
+ max_fps: Union[int, float] = MAX_VIDEO_FPS,
204
+ ) -> list[float]:
205
+ """
206
+ Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
207
+
208
+ Examples:
209
+ >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
210
+ [2.0, 6.0]
211
+ >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
212
+ [1.0, 5.0]
213
+ >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
214
+ [2.0]
215
+ >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
216
+ Traceback (most recent call last):
217
+ ...
218
+ ValueError: sampling_fps=2 must divide video_fps=5.
219
+ """
220
+ video_fps = int(video_fps)
221
+ sampling_fps = int(sampling_fps)
222
+ max_fps = int(max_fps)
223
+
224
+ if sampling_fps is None:
225
+ raise ValueError("sampling_fps must be provided")
226
+ if video_fps <= 0 or sampling_fps <= 0:
227
+ raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
228
+ if video_fps % sampling_fps != 0:
229
+ raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
230
+
231
+ candidates = []
232
+ for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
233
+ if candidate > max_fps:
234
+ break
235
+ if video_fps % candidate == 0:
236
+ candidates.append(float(candidate))
237
+
238
+ return candidates
239
+
240
+
241
+ def read_video_decord(
242
+ video_path,
243
+ sample_timestamps_fn: Callable,
244
+ **kwargs,
245
+ ) -> np.ndarray:
246
+ """
247
+ Decode a video using the Decord backend.
248
+
249
+ Args:
250
+ video_path (`str`):
251
+ Path to the video file.
252
+ sample_timestamps_fn (`Callable`):
253
+ A callable function that will return timestamps at which the video should be sampled.
254
+
255
+ Returns:
256
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
257
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
258
+ - `VideoMetadata` object.
259
+ """
260
+ # Lazy import from decord
261
+ import importlib
262
+ decord = importlib.import_module("decord")
263
+
264
+ vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
265
+ video_fps = vr.get_avg_fps()
266
+ total_num_frames = len(vr)
267
+ time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
268
+ duration = time_stamps[-1][1] - time_stamps[0][0]
269
+
270
+ metadata = VideoMetadata(
271
+ total_num_frames=int(total_num_frames),
272
+ fps=float(video_fps),
273
+ duration=float(duration),
274
+ video_backend="decord",
275
+ )
276
+
277
+ target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
278
+ target_timestamps = np.array(target_timestamps)
279
+ offset = time_stamps[0, 0]
280
+
281
+ ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
282
+ ix = np.minimum(ix, len(time_stamps) - 1)
283
+
284
+ video = vr.get_batch(ix).asnumpy()
285
+ metadata.update(
286
+ {
287
+ "frames_indices": target_timestamps * video_fps,
288
+ "height": video.shape[1],
289
+ "width": video.shape[2],
290
+ }
291
+ )
292
+ return video, metadata
293
+
294
+
295
+ def read_video_torchcodec(
296
+ video_path,
297
+ sample_timestamps_fn: Callable,
298
+ **kwargs,
299
+ ) -> np.ndarray:
300
+ """
301
+ Decode a video using torchcodec decoder.
302
+
303
+ Args:
304
+ video_path (`str`):
305
+ Path to the video file.
306
+ sample_timestamps_fn (`Callable`):
307
+ A callable function that will return timestamps at which the video should be sampled.
308
+
309
+ Returns:
310
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
311
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
312
+ - `VideoMetadata` object.
313
+ """
314
+ # Lazy import torchcodec
315
+ import importlib
316
+ torchcodec = importlib.import_module("torchcodec")
317
+
318
+ decoder = torchcodec.decoders.VideoDecoder(
319
+ video_path,
320
+ # Interestingly, `exact` mode takes less time than approximate mode when we load the whole video
321
+ seek_mode="exact",
322
+ # Let FFmpeg decide on the number of threads for efficiency
323
+ num_ffmpeg_threads=0,
324
+ )
325
+ # If the first frame starts at > 0, we effectively clip the video starting at that time
326
+ # since (most) video players would also skip to that time
327
+ time_offset = decoder.metadata.begin_stream_seconds_from_content
328
+ # Note this duration does assume we started playing at `time_offset`
329
+ duration = decoder.metadata.duration_seconds
330
+
331
+ metadata = VideoMetadata(
332
+ total_num_frames=decoder.metadata.num_frames,
333
+ fps=decoder.metadata.average_fps,
334
+ duration=duration,
335
+ video_backend="torchcodec",
336
+ height=decoder.metadata.height,
337
+ width=decoder.metadata.width,
338
+ )
339
+
340
+ target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
341
+
342
+ # Floating point/rounding issues might cause `target_timestamps` to be very slightly
343
+ # out-of-bounds, to handle this we sanity check then clip them
344
+ assert all(x >= 0 for x in target_timestamps)
345
+ assert all(x < duration+1e-6 for x in target_timestamps)
346
+ # 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
347
+ # exact boundary value, we should still get the first/last frame anyway
348
+ max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
349
+ min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
350
+ # Note we avoid using numpy ops here to reduce floating precision issues
351
+ timestamps = [x + time_offset for x in target_timestamps]
352
+ timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
353
+
354
+ video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format
355
+ target_timestamps = np.array(target_timestamps)
356
+ metadata.frames_indices = target_timestamps * metadata.fps
357
+
358
+ return video, metadata
359
+
360
+
361
+ def read_video_pyav(
362
+ video_path,
363
+ sample_timestamps_fn: Callable,
364
+ **kwargs,
365
+ ) -> np.ndarray:
366
+ """
367
+ Decode a video using the PyAV backend.
368
+
369
+ Args:
370
+ video_path (`str`):
371
+ Path to the video file.
372
+ sample_timestamps_fn (`Callable`):
373
+ A callable function that will return timestamps at which the video should be sampled.
374
+
375
+ Returns:
376
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
377
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
378
+ - `VideoMetadata` object.
379
+ """
380
+ # Lazy import PyAV
381
+ import importlib
382
+ av = importlib.import_module("av")
383
+
384
+ with av.open(video_path) as container:
385
+ video_stream = container.streams.video[0]
386
+ fps = video_stream.average_rate or video_stream.guessed_rate
387
+ it = container.decode(video=0)
388
+ frames = list(it)
389
+
390
+ stream = container.streams.video[0]
391
+ start = frames[0].pts * stream.time_base
392
+ container_end = stream.duration
393
+ if container_end is not None:
394
+ container_end *= stream.time_base
395
+ if container_end is None or container_end < frames[-1].pts:
396
+ # Some problem with stream duration, so use the frame PTS directly
397
+ # and guess the duration of the last frame
398
+ end = frames[-1].pts * stream.time_base + 1/fps
399
+ else:
400
+ end = container_end
401
+ duration = float(end - start)
402
+
403
+ metadata = VideoMetadata(
404
+ total_num_frames=len(frames),
405
+ fps=float(fps),
406
+ duration=float(duration),
407
+ video_backend="pyav",
408
+ height=video_stream.height,
409
+ width=video_stream.width,
410
+ )
411
+
412
+ target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
413
+ offset = float(start)
414
+
415
+ target_timestamps = np.array(target_timestamps)
416
+ end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
417
+ indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
418
+ indices = np.minimum(indices, len(end_time_stamps) - 1)
419
+
420
+ video = np.stack(
421
+ [frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
422
+ axis=0,
423
+ )
424
+
425
+ metadata.frames_indices = target_timestamps * fps
426
+
427
+ return video, metadata
428
+
429
+
430
+ VIDEO_DECODERS = {
431
+ "decord": read_video_decord,
432
+ "torchcodec": read_video_torchcodec,
433
+ "pyav": read_video_pyav,
434
+ }
435
+
436
+
437
+ def load_video(
438
+ video: VideoInput,
439
+ backend: str = "decord",
440
+ sample_timestamps_fn: Optional[Callable] = None,
441
+ **kwargs,
442
+ ):
443
+ """
444
+ Loads `video` to a numpy array.
445
+
446
+ Args:
447
+ video (`VideoInput`):
448
+ The video to convert to the numpy array format. Can be a link to a video or a local path.
449
+ backend (`str`, *optional*, defaults to `"decord"`):
450
+ The backend to use when loading the video. Can be any of ["decord", "pyav", "torchcodec"]. Defaults to "decord".
451
+ sample_timestamps_fn (`Callable`):
452
+ A callable function that will return timestamps at which the video should be sampled.
453
+ """
454
+
455
+ # Early exit if provided an array or `PIL` frames
456
+ if not isinstance(video, str):
457
+ metadata = [None] * len(video)
458
+ return video, metadata
459
+
460
+ if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
461
+ if not is_yt_dlp_available():
462
+ raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
463
+ # Lazy import from yt_dlp
464
+ import importlib
465
+ yt_dlp = importlib.import_module("yt_dlp")
466
+
467
+ buffer = BytesIO()
468
+ with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
469
+ f.download([video])
470
+ bytes_obj = buffer.getvalue()
471
+ file_obj = BytesIO(bytes_obj)
472
+ elif video.startswith("http://") or video.startswith("https://"):
473
+ file_obj = BytesIO(requests.get(video).content)
474
+ elif os.path.isfile(video):
475
+ file_obj = video
476
+ else:
477
+ raise TypeError("Incorrect format used for video. Should be a URL linking to a video or a local path.")
478
+
479
+ # can also load with decord, but not cv2/torchvision
480
+ # both will fail in case of url links
481
+ video_is_url = video.startswith("http://") or video.startswith("https://")
482
+ if video_is_url and backend == "opencv":
483
+ raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
484
+
485
+ if (
486
+ (not is_decord_available() and backend == "decord")
487
+ or (not is_torchcodec_available() and backend == "torchcodec")
488
+ or (not is_av_available() and backend == "pyav")
489
+ ):
490
+ raise ImportError(
491
+ f"You chose backend={backend} for loading the video but the required library is not found in your environment. "
492
+ f"Make sure to install {backend} before loading the video."
493
+ )
494
+
495
+ video_decoder = VIDEO_DECODERS[backend]
496
+ video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
497
+ return video, metadata
498
+
499
+
500
+ def get_target_fps(
501
+ video_fps: float,
502
+ max_frames: int,
503
+ total_frames: int,
504
+ frame_sample_mode: str,
505
+ candidate_target_fps: tuple[float],
506
+ ) -> float:
507
+ """
508
+ Get the target fps that best spans the video and has the most frames sampled
509
+ """
510
+ num_frames_sampled = 0
511
+ selected_target_fps = None
512
+ for target_fps in candidate_target_fps:
513
+ step_size = max(int(video_fps / target_fps), 1)
514
+ num_frames_sampled_at_fps = int(total_frames / step_size)
515
+ if num_frames_sampled == 0:
516
+ if "uniform" in frame_sample_mode:
517
+ if num_frames_sampled_at_fps > max_frames:
518
+ break
519
+ selected_target_fps = target_fps
520
+ num_frames_sampled = num_frames_sampled_at_fps
521
+
522
+ else:
523
+ # the candidate sampling fps increases so frame count can't decrease
524
+ assert num_frames_sampled <= num_frames_sampled_at_fps
525
+ if num_frames_sampled_at_fps > max_frames:
526
+ # choose the sampling fps that spans the video
527
+ continue
528
+
529
+ elif num_frames_sampled_at_fps > num_frames_sampled:
530
+ # both are less than max_frames, choose the one with higher density of frames sampled
531
+ selected_target_fps = target_fps
532
+ num_frames_sampled = num_frames_sampled_at_fps
533
+ return selected_target_fps
534
+
535
+
536
+ def get_frame_times_and_chosen_fps(
537
+ selected_target_fps,
538
+ total_frames,
539
+ max_frames,
540
+ video_fps
541
+ ):
542
+ if selected_target_fps is None:
543
+ frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
544
+ else:
545
+ step_size = max(int(video_fps / selected_target_fps), 1)
546
+ frame_indices = np.arange(0, total_frames, step_size)
547
+ if len(frame_indices) > max_frames:
548
+ frame_indices = frame_indices[:max_frames]
549
+ return selected_target_fps, frame_indices
550
+
551
+
552
+ class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
553
+ patch_size: Optional[int]
554
+ pooling_size: Optional[list[int]]
555
+ frame_sample_mode: Optional[str]
556
+ max_fps: Optional[int]
557
+ sampling_fps: Optional[int]
558
+
559
+
560
+ class MolmoAct2VideoProcessor(BaseVideoProcessor):
561
+ resample = PILImageResampling.BILINEAR
562
+ size = {"height": 378, "width": 378}
563
+ image_mean = IMAGENET_STANDARD_MEAN
564
+ image_std = IMAGENET_STANDARD_STD
565
+ do_resize = True
566
+ do_rescale = True
567
+ do_normalize = True
568
+ do_convert_rgb = True
569
+ patch_size = 14
570
+ pooling_size = [3, 3]
571
+ do_sample_frames = True
572
+ frame_sample_mode = "uniform_last_frame"
573
+ max_fps = 2
574
+ sampling_fps = 2
575
+ valid_kwargs = MolmoAct2VideoProcessorKwargs
576
+ model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
577
+
578
+ def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
579
+ super().__init__(**kwargs)
580
+ if self.size is not None and (
581
+ self.size.get("height", None) is None or self.size.get("width", None) is None
582
+ ):
583
+ raise ValueError("size must contain 'height' and 'width' keys.")
584
+
585
+ def _further_process_kwargs(
586
+ self,
587
+ size: Optional[SizeDict] = None,
588
+ **kwargs,
589
+ ) -> dict:
590
+ """
591
+ Update kwargs that need further processing before being validated.
592
+ Can be overridden by subclasses to customize the processing of kwargs.
593
+ """
594
+ if size is not None and ("height" not in size or "width" not in size):
595
+ raise ValueError("size must contain 'height' and 'width' keys.")
596
+
597
+ return super()._further_process_kwargs(size=size, **kwargs)
598
+
599
+ def sample_times(
600
+ self,
601
+ metadata: VideoMetadata,
602
+ frame_sample_mode: str,
603
+ num_frames: int,
604
+ max_fps: Optional[int] = None,
605
+ sampling_fps: Optional[int] = None,
606
+ **kwargs,
607
+ ) -> np.ndarray:
608
+ """
609
+ Time-based sampling used when a video path or URL is passed and frames are sampled during decoding
610
+ Args:
611
+ metadata (`VideoMetadata`):
612
+ Metadata of the video containing information about total duration, fps and total number of frames.
613
+ frame_sample_mode (`str`, *optional*):
614
+ Mode to sample frames. Defaults to `self.frame_sample_mode`.
615
+ num_frames (`int`, *optional*):
616
+ Maximum number of frames to sample. Defaults to `self.num_frames`.
617
+ max_fps (`int`, *optional*):
618
+ Maximum frames per second to sample.
619
+ sampling_fps (`int`, *optional*):
620
+ Sampling frames per second. Defaults to `self.sampling_fps`.
621
+ Used when `frame_sample_mode` is `"fps"`.
622
+ """
623
+ frame_sample_mode = frame_sample_mode or self.frame_sample_mode
624
+ num_frames = num_frames or self.num_frames
625
+ sampling_fps = sampling_fps or self.sampling_fps
626
+
627
+ duration = metadata.duration or metadata.total_num_frames / metadata.fps
628
+ if frame_sample_mode == "fps":
629
+ candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
630
+ # Try larger and larger FPSs until we hit one that can't span the video
631
+ target_fps = candidate_target_fps[0]
632
+ for candidate_fps in candidate_target_fps[1:]:
633
+ if num_frames / candidate_fps < duration:
634
+ break
635
+ target_fps = candidate_fps
636
+ times = np.arange(0, num_frames) / target_fps
637
+ times = times[times < duration]
638
+ return times
639
+ elif frame_sample_mode == "uniform_last_frame":
640
+ if max_fps is not None:
641
+ max_duration = (num_frames-1) / max_fps # -1 to include the last frame
642
+ if max_duration < duration:
643
+ times = np.linspace(
644
+ 0, duration, num=num_frames, endpoint=True, dtype=np.float64
645
+ )
646
+ else:
647
+ times = np.arange(0.0, stop=duration, step=1/max_fps)
648
+ times = np.concatenate([times, [duration]], axis=0)
649
+ assert len(times) <= num_frames
650
+ else:
651
+ times = np.linspace(
652
+ 0, duration, num=num_frames, endpoint=True, dtype=np.float64
653
+ )
654
+ return times
655
+ else:
656
+ raise NotImplementedError(frame_sample_mode)
657
+
658
+ def sample_frames(
659
+ self,
660
+ metadata: VideoMetadata,
661
+ frame_sample_mode: Optional[str] = None,
662
+ num_frames: Optional[int] = None,
663
+ max_fps: Optional[int] = None,
664
+ sampling_fps: Optional[int] = None,
665
+ **kwargs,
666
+ ) -> np.ndarray:
667
+ """
668
+ Frame-based sampling if an array video is passed
669
+ Args:
670
+ metadata (`VideoMetadata`):
671
+ Metadata of the video containing information about total duration, fps and total number of frames.
672
+ frame_sample_mode (`str`, *optional*):
673
+ Mode to sample frames. Defaults to `self.frame_sample_mode`.
674
+ num_frames (`int`, *optional*):
675
+ Maximum number of frames to sample. Defaults to `self.num_frames`.
676
+ max_fps (`int`, *optional*):
677
+ Maximum frames per second to sample.
678
+ sampling_fps (`int`, *optional*):
679
+ Sampling frames per second. Defaults to `self.sampling_fps`.
680
+ Used when `frame_sample_mode` is `"fps"`.
681
+ """
682
+ frame_sample_mode = frame_sample_mode or self.frame_sample_mode
683
+ num_frames = num_frames or self.num_frames
684
+ sampling_fps = sampling_fps or self.sampling_fps
685
+
686
+ total_num_frames = metadata.total_num_frames
687
+ if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
688
+ duration = total_num_frames / metadata.fps
689
+ if total_num_frames <= 2:
690
+ return np.arange(total_num_frames).astype(int)
691
+ if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
692
+ # uniform fallback
693
+ indices = np.linspace(
694
+ 0,
695
+ total_num_frames - 1,
696
+ num=min(num_frames, total_num_frames),
697
+ endpoint=True,
698
+ ).astype(int)
699
+ return indices
700
+ else:
701
+ float_indices = np.arange(
702
+ 0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
703
+ )
704
+ if np.round(float_indices[-1]) != total_num_frames - 1:
705
+ float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
706
+ indices = np.round(float_indices).astype(int)
707
+ assert indices[-1] < total_num_frames
708
+ assert len(float_indices) <= num_frames
709
+ return indices
710
+ elif frame_sample_mode == "uniform_last_frame":
711
+ indices = np.linspace(
712
+ 0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
713
+ ).astype(int)
714
+ return indices
715
+ elif frame_sample_mode == "fps":
716
+ candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
717
+ selected_target_fps = get_target_fps(
718
+ metadata.fps,
719
+ num_frames,
720
+ total_num_frames,
721
+ frame_sample_mode,
722
+ candidate_target_fps,
723
+ )
724
+ _, indices = get_frame_times_and_chosen_fps(
725
+ selected_target_fps,
726
+ total_num_frames,
727
+ num_frames,
728
+ metadata.fps,
729
+ )
730
+ return indices
731
+ else:
732
+ raise NotImplementedError(frame_sample_mode)
733
+
734
+ def fetch_videos(
735
+ self,
736
+ video_url_or_urls: Union[str, list[str], list[list[str]]],
737
+ sample_timestamps_fn=None
738
+ ):
739
+ """
740
+ Convert a single or a list of urls into the corresponding `np.array` objects.
741
+
742
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
743
+ returned.
744
+ """
745
+ if (
746
+ (not is_decord_available())
747
+ and (not is_torchcodec_available())
748
+ and (not is_av_available())
749
+ ):
750
+ raise ImportError(
751
+ "MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
752
+ )
753
+
754
+ if is_decord_available():
755
+ backend = "decord"
756
+ elif is_torchcodec_available():
757
+ warnings.warn(
758
+ "`decord` is not installed and cannot be used to decode the video by default. "
759
+ "Falling back to `torchcodec`."
760
+ )
761
+ backend = "torchcodec"
762
+ else:
763
+ warnings.warn(
764
+ "`decord` is not installed and cannot be used to decode the video by default. "
765
+ "Falling back to `PyAV`."
766
+ )
767
+ backend = "pyav"
768
+
769
+ if isinstance(video_url_or_urls, list):
770
+ return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
771
+ else:
772
+ return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
773
+
774
+ def _decode_and_sample_videos(
775
+ self,
776
+ videos: VideoInput,
777
+ video_metadata: Union[VideoMetadata, dict],
778
+ do_sample_frames: Optional[bool] = None,
779
+ sample_indices_fn: Optional[Callable] = None,
780
+ sample_timestamps_fn: Optional[Callable] = None,
781
+ ):
782
+ """
783
+ Decode input videos and sample frames if needed.
784
+ """
785
+ videos = make_batched_videos(videos)
786
+ video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
787
+
788
+ # Frame-based sampling if an array video is passed
789
+ # Otherwise, time-based sampling with decoding
790
+ if is_valid_video(videos[0]) and do_sample_frames:
791
+ assert video_metadata[0].fps is not None, "FPS must be provided for video input"
792
+ sampled_videos = []
793
+ sampled_metadata = []
794
+ for video, metadata in zip(videos, video_metadata):
795
+ indices = sample_indices_fn(metadata=metadata)
796
+ metadata.frames_indices = indices
797
+ sampled_videos.append(video[indices])
798
+ sampled_metadata.append(metadata)
799
+ videos = sampled_videos
800
+ video_metadata = sampled_metadata
801
+ elif not is_valid_video(videos[0]):
802
+ if sample_indices_fn is None:
803
+ logger.warning(
804
+ "do_sample_frames is False, but video array is not provided: "
805
+ "Will decode the video and sample frames using MolmoAct2's default sampling mode"
806
+ )
807
+ if isinstance(videos[0], list):
808
+ raise ValueError(
809
+ "A list of images is not supported for video input!"
810
+ )
811
+ else:
812
+ videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
813
+
814
+ return videos, video_metadata
815
+
816
+ def _prepare_input_videos(
817
+ self,
818
+ videos: VideoInput,
819
+ **kwargs,
820
+ ) -> list[np.ndarray]:
821
+ processed_videos = [to_numpy(video) for video in videos]
822
+ return processed_videos
823
+
824
+ def preprocess(
825
+ self,
826
+ videos: VideoInput,
827
+ **kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
828
+ ) -> BatchFeature:
829
+ validate_kwargs(
830
+ captured_kwargs=kwargs.keys(),
831
+ valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
832
+ )
833
+
834
+ # Set default kwargs from self. This ensures that if a kwarg is not provided
835
+ # by the user, it gets its default value from the instance, or is set to None.
836
+ for kwarg_name in self.valid_kwargs.__annotations__:
837
+ kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
838
+
839
+ do_sample_frames = kwargs.pop("do_sample_frames")
840
+ video_metadata = kwargs.pop("video_metadata")
841
+
842
+ sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
843
+ sample_timestamps_fn = partial(self.sample_times, **kwargs)
844
+ videos, video_metadata = self._decode_and_sample_videos(
845
+ videos,
846
+ video_metadata=video_metadata,
847
+ do_sample_frames=do_sample_frames,
848
+ sample_indices_fn=sample_indices_fn,
849
+ sample_timestamps_fn=sample_timestamps_fn,
850
+ )
851
+ videos = self._prepare_input_videos(videos=videos)
852
+
853
+ kwargs = self._further_process_kwargs(**kwargs)
854
+
855
+ return_metadata = kwargs.pop("return_metadata")
856
+ preprocessed_videos = self._preprocess(videos=videos, **kwargs)
857
+ if return_metadata:
858
+ preprocessed_videos["video_metadata"] = video_metadata
859
+ return preprocessed_videos
860
+
861
+ def _preprocess(
862
+ self,
863
+ videos: list[np.ndarray],
864
+ size: Optional[SizeDict] = None,
865
+ resample: Optional[PILImageResampling] = None,
866
+ image_mean: Optional[Union[float, list[float]]] = None,
867
+ image_std: Optional[Union[float, list[float]]] = None,
868
+ do_convert_rgb: Optional[bool] = None,
869
+ patch_size: Optional[int] = None,
870
+ pooling_size: Optional[list[int]] = None,
871
+ return_tensors: Optional[Union[str, TensorType]] = None,
872
+ **kwargs,
873
+ ) -> BatchFeature:
874
+ """
875
+ Preprocess a video for the model.
876
+ Args:
877
+ videos (`VideoInput`):
878
+ Video to preprocess.
879
+ size (`SizeDict`, *optional*, defaults to `self.size`):
880
+ Size of the image after resizing.
881
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
882
+ Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
883
+ has an effect if `do_resize` is set to `True`.
884
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
885
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
886
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
887
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
888
+ `True`.
889
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
890
+ Whether to convert the image to RGB.
891
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
892
+ The spatial patch size of the vision encoder.
893
+ pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
894
+ The pooling size of the vision adapter.
895
+ return_tensors (`str` or `TensorType`, *optional*):
896
+ The type of tensors to return. Can be one of:
897
+ - Unset: Return a list of `np.ndarray`.
898
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
899
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
900
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
901
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
902
+
903
+ Returns:
904
+ A `BatchFeature` containing the following keys:
905
+ - `pixel_values_videos`: The preprocessed videos.
906
+ - `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
907
+ - `video_grids`: The video grids.
908
+ """
909
+ if size.height is None or size.width is None:
910
+ raise ValueError("size must contain 'height' and 'width' keys.")
911
+
912
+ base_image_input_size = [size.height, size.width]
913
+
914
+ resample = resample or self.resample
915
+ image_mean = image_mean or self.image_mean
916
+ image_std = image_std or self.image_std
917
+ do_convert_rgb = do_convert_rgb or self.do_convert_rgb
918
+
919
+ patch_size = patch_size or self.patch_size
920
+ pooling_size = pooling_size or self.pooling_size
921
+
922
+ image_pooling_h, image_pooling_w = pooling_size
923
+
924
+ batch_grids = []
925
+ batch_crops = []
926
+ batch_pooled_patches_idx = []
927
+
928
+ for video in videos:
929
+ all_crops = []
930
+ pooled_patches_idx = []
931
+
932
+ for frame in video:
933
+ image_grid, crops, pooled_idx = image_to_patches_and_grids(
934
+ frame,
935
+ base_image_input_size,
936
+ resample,
937
+ image_mean,
938
+ image_std,
939
+ patch_size,
940
+ image_pooling_w,
941
+ image_pooling_h,
942
+ )
943
+ offset = sum(np.prod(x.shape[:2]) for x in all_crops)
944
+ pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
945
+ pooled_patches_idx.append(pooled_idx_with_offset)
946
+ all_crops.append(crops)
947
+
948
+ video_grid = np.array([len(video), image_grid[0], image_grid[1]])
949
+ all_crops = np.concatenate(all_crops, 0)
950
+ pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
951
+
952
+ batch_grids.append(video_grid)
953
+ batch_crops.append(all_crops)
954
+ batch_pooled_patches_idx.append(pooled_patches_idx)
955
+
956
+ video_grids = np.stack(batch_grids, 0)
957
+ pixel_values_videos = np.concatenate(batch_crops, 0)
958
+ video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
959
+
960
+ data = dict(
961
+ pixel_values_videos=pixel_values_videos,
962
+ video_token_pooling=video_token_pooling,
963
+ video_grids=video_grids,
964
+ )
965
+
966
+ return BatchFeature(data, tensor_type=return_tensors)
967
+
968
+
969
+ MolmoAct2VideoProcessor.register_for_auto_class()
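
To make the default `"uniform_last_frame"` sampling concrete, the standalone sketch below mirrors the index selection performed by `sample_frames` above for a hypothetical 30 fps, 300-frame clip with `num_frames=8` and `max_fps=2`. It is an illustration of the rule, not part of the repo.

```python
import numpy as np

# Hypothetical clip: 300 frames at 30 fps (10 s); defaults num_frames=8, max_fps=2.
total_num_frames, fps = 300, 30.0
num_frames, max_fps = 8, 2.0

duration = total_num_frames / fps
if duration > (num_frames - 1) / max_fps:
    # Clip too long to cover at max_fps with num_frames frames:
    # fall back to uniform sampling that always keeps the last frame.
    indices = np.linspace(0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True).astype(int)
else:
    # Short clip: step at max_fps and append the last frame if it was missed.
    float_indices = np.arange(0.0, stop=total_num_frames - 1, step=fps / max_fps)
    if np.round(float_indices[-1]) != total_num_frames - 1:
        float_indices = np.concatenate([float_indices, [total_num_frames - 1]])
    indices = np.round(float_indices).astype(int)

print(indices)  # [  0  42  85 128 170 213 256 299] -> 8 frames, uniformly spaced, ending on the last frame
```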