Add files using upload-large-folder tool
- .gitattributes +1 -0
- README.md +159 -3
- assets/MolmoAct2.svg +26 -0
- assets/sample_agentview_rgb.png +0 -0
- assets/sample_wrist_rgb.png +0 -0
- chat_template.jinja +1 -0
- config.json +160 -0
- configuration_molmoact2.py +565 -0
- generation_config.json +6 -0
- image_processing_molmoact2.py +546 -0
- inference.py +768 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_molmoact2.py +0 -0
- norm_stats.json +238 -0
- processing_molmoact2.py +418 -0
- processor_config.json +85 -0
- tokenizer.json +3 -0
- tokenizer_config.json +34 -0
- video_processing_molmoact2.py +969 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text

README.md
CHANGED
@@ -1,3 +1,159 @@
---
library_name: transformers
tags:
- molmoact2
- robotics
- image-text-to-text
- libero
---

<img src="assets/MolmoAct2.svg" alt="MolmoAct Logo" style="width: auto; height: 50px;">

# **MolmoAct2-LIBERO**

MolmoAct2 is an open vision-language-action model for robot control. It builds on Molmo2-ER and attaches a flow-matching continuous action expert that conditions on the VLM key-value cache through a per-layer connection.

This checkpoint is fine-tuned on the full LIBERO training mixture, combining Spatial, Object, Goal, and Long suites. It is intended for both further fine-tuning and LIBERO policy inference.

## Quick Links

- 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
- 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
- 📄 Paper:
- 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
- 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)

## Intended Use

Use this checkpoint for LIBERO inference or for further fine-tuning. Dataset normalization metadata is stored in `norm_stats.json`; pass `norm_tag="libero"` at inference time.

Continuous action prediction is the intended and recommended inference mode. Discrete action prediction is exposed for parity and debugging, but we use continuous actions by default.
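
If you want to see what normalization metadata ships with the checkpoint before running inference, you can peek at `norm_stats.json`. This is only a sketch; the exact key layout inside the file is not documented here, so inspect whatever the file actually contains:

```python
import json
from huggingface_hub import hf_hub_download

# Download and inspect the normalization stats shipped with the checkpoint.
stats_path = hf_hub_download("allenai/MolmoAct2-LIBERO", "norm_stats.json")
with open(stats_path) as f:
    norm_stats = json.load(f)

# A "libero" entry is expected here, matching norm_tag="libero" used below.
print(list(norm_stats.keys()))
```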

## Install

```bash
pip install torch transformers pillow numpy huggingface_hub
```

## Sample Input

This sample comes from `libero_10`, episode 0, frame 0. The LIBERO camera order is front/agent view followed by wrist view.

| Agentview RGB | Wrist RGB |
| --- | --- |
|  |  |

```python
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np

repo_id = "allenai/MolmoAct2-LIBERO"

agentview_rgb = Image.open(
    hf_hub_download(repo_id, "assets/sample_agentview_rgb.png")
).convert("RGB")
wrist_rgb = Image.open(
    hf_hub_download(repo_id, "assets/sample_wrist_rgb.png")
).convert("RGB")

task = "put the white mug on the left plate and put the yellow and white mug on the right plate"
robot_state = np.array(
    [
        -0.05338004603981972,
        0.007029631175100803,
        0.6783280968666077,
        3.1407692432403564,
        0.0017593271331861615,
        -0.08994418382644653,
        0.03878866136074066,
        -0.03878721222281456,
    ],
    dtype=np.float32,
)
```

## Continuous Actions

```python
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

repo_id = "allenai/MolmoAct2-LIBERO"

agentview_rgb = Image.open(
    hf_hub_download(repo_id, "assets/sample_agentview_rgb.png")
).convert("RGB")
wrist_rgb = Image.open(
    hf_hub_download(repo_id, "assets/sample_wrist_rgb.png")
).convert("RGB")
task = "put the white mug on the left plate and put the yellow and white mug on the right plate"
robot_state = np.array(
    [
        -0.05338004603981972,
        0.007029631175100803,
        0.6783280968666077,
        3.1407692432403564,
        0.0017593271331861615,
        -0.08994418382644653,
        0.03878866136074066,
        -0.03878721222281456,
    ],
    dtype=np.float32,
)

processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.float32,
).to("cuda").eval()

out = model.predict_action(
    processor=processor,
    images=[agentview_rgb, wrist_rgb],
    task=task,
    state=robot_state,
    norm_tag="libero",
    action_mode="continuous",
    enable_depth_reasoning=False,
    num_steps=10,
    normalize_language=True,
    enable_cuda_graph=True,
)

actions = out.actions
```

`images` should preserve camera order, for example `[agentview_rgb, wrist_rgb]`. Images may be PIL images or RGB arrays. `state` is the raw robot state, and actions are returned in robot scale.
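
As a rough illustration of how the returned action chunk can drive a rollout, here is a minimal open-loop control-loop sketch. The environment handle `env` and the helpers `get_obs_images` / `get_robot_state` are placeholders for your own LIBERO wrapper, not APIs provided by this repository:

```python
# Sketch of an open-loop chunk rollout. `env`, `get_obs_images`, and
# `get_robot_state` are assumed helpers, not part of this repo.
obs = env.reset()
done = False
while not done:
    out = model.predict_action(
        processor=processor,
        images=get_obs_images(obs),    # [agentview_rgb, wrist_rgb], in camera order
        task=task,
        state=get_robot_state(obs),    # raw robot state, same layout as `robot_state` above
        norm_tag="libero",
        action_mode="continuous",
    )
    for action in out.actions:         # execute the predicted chunk, then re-plan
        obs, reward, done, info = env.step(action)
        if done:
            break
```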

`normalize_language=True` is the default. It lowercases the task string and removes trailing sentence punctuation to match training preprocessing. Set it to `False` if you need to preserve the task text exactly.
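
For reference, the documented normalization amounts to roughly the following. This is a sketch of the described behavior, not the processor's internal code:

```python
def normalize_task(task: str) -> str:
    # Lowercase and strip trailing sentence punctuation, mirroring the documented
    # behavior of normalize_language=True (sketch only).
    return task.lower().rstrip(".!?").strip()

print(normalize_task("Put the white mug on the left plate."))
# put the white mug on the left plate
```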

`enable_cuda_graph=True` is the default. The first few calls can be slow because the model warms up and captures CUDA graphs; run several random warm-up calls before measuring deployment latency. `num_steps` controls the continuous flow solver and defaults to the checkpoint config value, 10.
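
One possible warm-up routine before timing is sketched below; the dummy image size and the number of warm-up calls are arbitrary choices, not values required by the model:

```python
import numpy as np
from PIL import Image

# Run a few throwaway calls so CUDA-graph capture and other one-time setup
# happen before latency is measured. Random dummy inputs are placeholders only.
dummy_images = [
    Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
    for _ in range(2)
]
dummy_state = np.random.randn(8).astype(np.float32)
for _ in range(3):
    model.predict_action(
        processor=processor,
        images=dummy_images,
        task="warm up",
        state=dummy_state,
        norm_tag="libero",
        action_mode="continuous",
    )
```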

Depth reasoning is disabled for this checkpoint. Calling `predict_action` with `enable_depth_reasoning=True` will raise an error.

## Discrete Actions

Discrete action inference requires a caller-provided action tokenizer. It is not saved in this repository. Discrete mode decodes action tokens directly; the continuous action expert is not used.

```python
action_tokenizer = AutoProcessor.from_pretrained(
    "YOUR_ACTION_TOKENIZER_REPO",
    trust_remote_code=True,
)

out = model.predict_action(
    processor=processor,
    images=[agentview_rgb, wrist_rgb],
    task=task,
    state=robot_state,
    norm_tag="libero",
    action_mode="discrete",
    action_tokenizer=action_tokenizer,
    enable_depth_reasoning=False,
)
```

assets/MolmoAct2.svg
ADDED

assets/sample_agentview_rgb.png
ADDED

assets/sample_wrist_rgb.png
ADDED

chat_template.jinja
ADDED
@@ -0,0 +1 @@
{% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}{% if not (has_subtitle and loop.index == 2) and not (not has_subtitle and loop.first) %}{{ '<|im_end|>\n' }}{% endif %}{{ '<|im_start|>user\n' }}{{ text_content }}{{ '<|im_end|>\n' }}{% else %} {# assistant #}{{ '<|im_start|>assistant\n' }}{{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
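
The template above expects messages whose content is a list of typed parts (text, image, video) and prepends the image/video placeholder tokens before the conversation. A small rendering sketch follows; whether chat-style inference (rather than `model.predict_action`) is the intended path for this checkpoint is an assumption, and this only illustrates the message structure the template accepts:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("allenai/MolmoAct2-LIBERO", trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What objects are on the table?"},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)
# <|image|><|im_start|>user
# What objects are on the table?<|im_end|>
# <|im_start|>assistant
```
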
config.json
ADDED
@@ -0,0 +1,160 @@
{
  "action_end_token_id": 151933,
  "action_expert_condition_source": "kv_cache",
  "action_expert_config": {
    "attn_dropout": 0.0,
    "causal_attn": false,
    "compile": "blocks",
    "context_layer_norm": true,
    "dropout": 0.0,
    "ffn_multiple_of": 256,
    "hidden_size": 768,
    "implementation": "new",
    "max_action_dim": 32,
    "max_horizon": 32,
    "mlp_ratio": 4.0,
    "model_type": "molmoact2_action_expert",
    "num_heads": 8,
    "num_layers": 36,
    "qk_norm": true,
    "qk_norm_eps": 1e-06,
    "rope": true,
    "rope_on_cross_attention": true,
    "timestep_embed_dim": 256
  },
  "action_expert_depth_gate": false,
  "action_expert_depth_gate_init_bias": -4.0,
  "action_expert_depth_gate_per_layer": false,
  "action_expert_layer_mode": "per_layer",
  "action_format": "both",
  "action_horizon": 10,
  "action_output_token_id": 151931,
  "action_start_token_id": 151932,
  "action_token_start_id": 151934,
  "adapter_config": {
    "attention_dropout": 0.0,
    "attn_implementation": "sdpa",
    "float32_attention": true,
    "head_dim": 72,
    "hidden_act": "silu",
    "hidden_size": 1152,
    "image_feature_dropout": 0.0,
    "initializer_range": 0.02,
    "intermediate_size": 9728,
    "model_type": "molmoact2",
    "num_attention_heads": 16,
    "num_key_value_heads": 16,
    "pooling_attention_mask": true,
    "residual_dropout": 0.0,
    "text_hidden_size": 2560,
    "vit_layers": [
      -3,
      -9
    ]
  },
  "add_action_expert": true,
  "add_control_tokens": true,
  "add_setup_tokens": true,
  "architectures": [
    "MolmoAct2ForConditionalGeneration"
  ],
  "auto_map": {
    "AutoConfig": "configuration_molmoact2.MolmoAct2Config",
    "AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
  },
  "depth_end_token_id": null,
  "depth_mode": 2,
  "depth_output_token_id": null,
  "depth_start_token_id": null,
  "depth_token_start_id": null,
  "dtype": "float32",
  "enable_depth_reasoning": false,
  "flow_matching_beta_alpha": 1.0,
  "flow_matching_beta_beta": 1.5,
  "flow_matching_cutoff": 1.0,
  "flow_matching_num_steps": 10,
  "flow_matching_time_offset": 0.001,
  "flow_matching_time_scale": 0.999,
  "frame_end_token_id": 154632,
  "frame_start_token_id": 154631,
  "image_col_id": 154627,
  "image_end_token_id": 154625,
  "image_high_res_id": 154626,
  "image_low_res_id": 154630,
  "image_patch_id": 154626,
  "image_start_token_id": 154624,
  "initializer_range": 0.02,
  "low_res_image_start_token_id": 154628,
  "mask_action_dim_padding": true,
  "max_action_dim": 32,
  "model_type": "molmoact2",
  "n_obs_steps": 1,
  "norm_stats_filename": "norm_stats.json",
  "num_action_tokens": 2048,
  "num_depth_codes": 100,
  "num_depth_tokens": 0,
  "num_state_tokens": 256,
  "state_end_token_id": 151674,
  "state_format": "discrete",
  "state_start_token_id": 151673,
  "state_token_start_id": 151675,
  "text_config": {
    "additional_vocab_size": 128,
    "attention_dropout": 0.0,
    "attn_implementation": "sdpa",
    "embedding_dropout": 0.0,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 2560,
    "initializer_range": 0.02,
    "intermediate_size": 9728,
    "layer_norm_eps": 1e-06,
    "max_position_embeddings": 16384,
    "model_type": "molmoact2_text",
    "norm_after": false,
    "num_attention_heads": 32,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "qk_norm_type": "qwen3",
    "qkv_bias": false,
    "residual_dropout": 0.0,
    "rope_parameters": {
      "rope_theta": 5000000.0,
      "rope_type": "default"
    },
    "rope_scaling_layers": null,
    "rope_theta": 5000000.0,
    "tie_word_embeddings": false,
    "use_cache": true,
    "use_qk_norm": true,
    "vocab_size": 154624
  },
  "tie_word_embeddings": false,
  "transformers_version": "5.3.0",
  "use_frame_special_tokens": true,
  "vit_config": {
    "attention_dropout": 0.0,
    "attn_implementation": "sdpa",
    "float32_attention": true,
    "head_dim": 72,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_default_input_size": [
      378,
      378
    ],
    "image_num_pos": 729,
    "image_patch_size": 14,
    "initializer_range": 0.02,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "molmoact2",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "num_key_value_heads": 16,
    "residual_dropout": 0.0
  },
  "bos_token_id": 151645,
  "eos_token_id": 151645,
  "pad_token_id": 151643
}
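
The action-related fields above can be read back through the exported config once the custom classes are loaded with `trust_remote_code`. A quick sketch (the printed comments simply restate the values in this config.json):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("allenai/MolmoAct2-LIBERO", trust_remote_code=True)
print(config.action_horizon)                    # 10 action steps per predicted chunk
print(config.max_action_dim)                    # 32 (padded maximum, not the robot's native dim)
print(config.flow_matching_num_steps)           # 10 flow-matching solver steps by default
print(config.enable_depth_reasoning)            # False for this checkpoint
print(config.action_expert_config.num_layers)   # 36, matching the 36 language-model layers
```
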
configuration_molmoact2.py
ADDED
@@ -0,0 +1,565 @@
"""
MolmoAct2 configuration
"""

from typing import Optional, Any

from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MolmoAct2VitConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
    It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
    defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:
    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer

    >>> # Initializing a MolmoAct2VitConfig
    >>> configuration = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2VisionTransformer (with random weights)
    >>> model = MolmoAct2VisionTransformer(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    base_config_key = "vit_config"

    def __init__(
        self,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        num_hidden_layers: int = 27,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        hidden_act: str = "gelu_pytorch_tanh",
        layer_norm_eps: float = 1e-6,
        image_default_input_size: tuple[int, int] = (378, 378),
        image_patch_size: int = 14,
        image_num_pos: int = 577,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        initializer_range: float = 0.02,
        float32_attention: bool = True,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        self.attn_implementation = attn_implementation
        super().__init__(
            attn_implementation=attn_implementation,
            **kwargs
        )
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.image_default_input_size = image_default_input_size
        self.image_patch_size = image_patch_size
        self.image_num_pos = image_num_pos
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.initializer_range = initializer_range
        self.float32_attention = float32_attention

    @property
    def image_num_patch(self):
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size


class MolmoAct2AdapterConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a MolmoAct2Adapter. Together with
    [`MolmoAct2VitConfig`], it is used to instantiate a MolmoAct2VisionBackbone according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone

    >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
    >>> vit_config = MolmoAct2VitConfig()
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
    >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)

    >>> # Accessing the model configuration
    >>> vit_configuration = model.vit_config
    >>> adapter_configuration = model.adapter_config
    ```"""

    model_type = "molmoact2"
    base_config_key = "adapter_config"

    def __init__(
        self,
        vit_layers: tuple = (-3, -9),
        pooling_attention_mask: bool = False,
        hidden_size: int = 1152,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        float32_attention: bool = True,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        hidden_act: str = "silu",
        intermediate_size: int = 18944,
        text_hidden_size: int = 3584,
        image_feature_dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        self.attn_implementation = attn_implementation
        super().__init__(
            attn_implementation=attn_implementation,
            **kwargs
        )
        self.vit_layers = vit_layers
        self.pooling_attention_mask = pooling_attention_mask
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.float32_attention = float32_attention
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.text_hidden_size = text_hidden_size
        self.image_feature_dropout = image_feature_dropout
        self.initializer_range = initializer_range


class MolmoAct2TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
    `MolmoAct2TextModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:
    ```python
    >>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel

    >>> # Initializing a MolmoAct2TextConfig
    >>> configuration = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2TextModel (with random weights)
    >>> model = MolmoAct2TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2_text"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "blocks.*.self_attn.att_proj": "colwise",
        "blocks.*.self_attn.attn_out": "rowwise",
        "blocks.*.mlp.ff_proj": "colwise",
        "blocks.*.mlp.ff_out": "rowwise",
    }
    base_model_pp_plan = {
        "wte": (["input_ids"], ["inputs_embeds"]),
        "blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "ln_f": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        hidden_size: int = 3584,
        num_attention_heads: int = 28,
        num_key_value_heads: Optional[int] = 4,
        head_dim: int = 128,
        vocab_size: int = 152064,
        additional_vocab_size: int = 128,
        qkv_bias: bool = True,
        num_hidden_layers: int = 48,
        intermediate_size: int = 18944,
        hidden_act: str = "silu",
        embedding_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        max_position_embeddings: int = 4096,
        rope_theta: float = 1000000.0,
        rope_scaling: dict[str, Any] = None,
        rope_scaling_layers: Optional[list[int]] = None,
        use_qk_norm: bool = False,
        qk_norm_type: str = "olmo",
        layer_norm_eps: float = 1e-6,
        norm_after: bool = False,
        initializer_range: float = 0.02,
        use_cache=True,
        tie_word_embeddings=False,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        self.attn_implementation = attn_implementation
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            attn_implementation=attn_implementation,
            **kwargs
        )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.vocab_size = vocab_size
        self.additional_vocab_size = additional_vocab_size
        self.qkv_bias = qkv_bias
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.embedding_dropout = embedding_dropout
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.rope_scaling_layers = rope_scaling_layers
        self.use_qk_norm = use_qk_norm
        self.qk_norm_type = qk_norm_type
        self.layer_norm_eps = layer_norm_eps
        self.norm_after = norm_after
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Validate the correctness of rotary position embeddings parameters
        rope_config_validation(self)


class MolmoAct2ActionExpertConfig(PretrainedConfig):
    r"""Configuration for the MolmoAct2 modern action expert."""

    model_type = "molmoact2_action_expert"
    base_config_key = "action_expert_config"

    def __init__(
        self,
        max_horizon: int = 32,
        max_action_dim: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 32,
        num_heads: int = 16,
        mlp_ratio: float = 8.0 / 3.0,
        ffn_multiple_of: int = 256,
        timestep_embed_dim: int = 256,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        context_layer_norm: bool = True,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        rope: bool = True,
        rope_on_cross_attention: bool = False,
        causal_attn: bool = False,
        compile: str = "blocks",
        implementation: str = "new",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if implementation != "new":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert.implementation='new'."
            )
        self.max_horizon = max_horizon
        self.max_action_dim = max_action_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.ffn_multiple_of = ffn_multiple_of
        self.timestep_embed_dim = timestep_embed_dim
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.context_layer_norm = context_layer_norm
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps
        self.rope = rope
        self.rope_on_cross_attention = rope_on_cross_attention
        self.causal_attn = causal_attn
        self.compile = compile
        self.implementation = implementation


class MolmoAct2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
    It is used to instantiate a MolmoAct2 model according to the specified arguments, defining the model architecture.

    Example:

    ```python
    >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig

    >>> # Initializing a MolmoAct2VitConfig
    >>> vit_config = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2AdapterConfig
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2TextConfig
    >>> text_config = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2Config
    >>> configuration = MolmoAct2Config(
    >>>     vit_config=vit_config,
    >>>     adapter_config=adapter_config,
    >>>     text_config=text_config,
    >>>     image_start_token_id=151936,
    >>>     image_end_token_id=151937,
    >>>     image_patch_id=151938,
    >>>     image_col_id=151939,
    >>>     low_res_image_start_token_id=151940,
    >>>     image_low_res_id=151942,
    >>>     frame_start_token_id=151943,
    >>>     frame_end_token_id=151944,
    >>> )

    >>> # Initializing a model
    >>> model = MolmoAct2ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    sub_configs = {
        "text_config": MolmoAct2TextConfig,
        "vit_config": MolmoAct2VitConfig,
        "adapter_config": MolmoAct2AdapterConfig,
        "action_expert_config": MolmoAct2ActionExpertConfig,
    }

    def __init__(
        self,
        vit_config: MolmoAct2VitConfig = None,
        adapter_config: MolmoAct2AdapterConfig = None,
        text_config: MolmoAct2TextConfig = None,
        action_expert_config: MolmoAct2ActionExpertConfig = None,
        image_start_token_id: int = None,
        low_res_image_start_token_id: int = None,
        image_end_token_id: int = None,
        image_low_res_id: int = None,
        image_patch_id: int = None,
        image_col_id: int = None,
        frame_start_token_id: int = None,
        frame_end_token_id: int = None,
        use_frame_special_tokens: bool = True,
        initializer_range: float = 0.02,
        add_action_expert: bool = True,
        max_action_dim: int = 7,
        action_horizon: int = 16,
        n_obs_steps: int = 1,
        action_format: str = "continuous",
        state_format: str = "discrete",
        action_expert_condition_source: str = "kv_cache",
        action_expert_layer_mode: str = "per_layer",
        flow_matching_num_steps: int = 10,
        flow_matching_cutoff: float = 1.0,
        flow_matching_time_offset: float = 0.001,
        flow_matching_time_scale: float = 0.999,
        flow_matching_beta_alpha: float = 1.0,
        flow_matching_beta_beta: float = 1.5,
        mask_action_dim_padding: bool = True,
        enable_depth_reasoning: bool = False,
        depth_mode: int = 2,
        num_depth_codes: int = 100,
        action_expert_depth_gate: bool = False,
        action_expert_depth_gate_per_layer: bool = False,
        action_expert_depth_gate_init_bias: float = -4.0,
        action_output_token_id: int = None,
        action_start_token_id: int = None,
        action_end_token_id: int = None,
        action_token_start_id: int = None,
        num_action_tokens: int = 0,
        depth_output_token_id: int = None,
        depth_start_token_id: int = None,
        depth_end_token_id: int = None,
        depth_token_start_id: int = None,
        num_depth_tokens: int = 0,
        state_start_token_id: int = None,
        state_end_token_id: int = None,
        state_token_start_id: int = None,
        num_state_tokens: int = 0,
        add_setup_tokens: bool = True,
        add_control_tokens: bool = True,
        norm_stats_filename: str = "norm_stats.json",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if vit_config is None:
            self.vit_config = MolmoAct2VitConfig()
        elif isinstance(vit_config, dict):
            self.vit_config = MolmoAct2VitConfig(**vit_config)
        else:
            self.vit_config = vit_config
        if adapter_config is None:
            self.adapter_config = MolmoAct2AdapterConfig()
        elif isinstance(adapter_config, dict):
            self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
        else:
            self.adapter_config = adapter_config
        if text_config is None:
            self.text_config = MolmoAct2TextConfig()
        elif isinstance(text_config, dict):
            self.text_config = MolmoAct2TextConfig(**text_config)
        else:
            self.text_config = text_config
        self.add_action_expert = bool(add_action_expert)
        if not self.add_action_expert:
            self.action_expert_config = None
        elif action_expert_config is None:
            self.action_expert_config = MolmoAct2ActionExpertConfig(
                max_horizon=action_horizon,
                max_action_dim=max_action_dim,
                num_layers=self.text_config.num_hidden_layers,
            )
        elif isinstance(action_expert_config, dict):
            self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
        else:
            self.action_expert_config = action_expert_config
        if self.add_action_expert:
            self._validate_release_action_config(
                action_expert_config=self.action_expert_config,
                action_expert_condition_source=action_expert_condition_source,
                action_expert_layer_mode=action_expert_layer_mode,
                state_format=state_format,
            )
        self.image_start_token_id = image_start_token_id
        self.low_res_image_start_token_id = low_res_image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.image_low_res_id = image_low_res_id
        self.image_high_res_id = image_patch_id
        self.image_patch_id = image_patch_id
        self.image_col_id = image_col_id
        self.frame_start_token_id = frame_start_token_id
        self.frame_end_token_id = frame_end_token_id
        self.use_frame_special_tokens = use_frame_special_tokens
        self.initializer_range = initializer_range
        self.max_action_dim = max_action_dim
        self.action_horizon = action_horizon
        self.n_obs_steps = n_obs_steps
        self.action_format = action_format
        self.state_format = state_format
        self.action_expert_condition_source = action_expert_condition_source
        self.action_expert_layer_mode = action_expert_layer_mode
        self.flow_matching_num_steps = flow_matching_num_steps
        self.flow_matching_cutoff = flow_matching_cutoff
        self.flow_matching_time_offset = flow_matching_time_offset
        self.flow_matching_time_scale = flow_matching_time_scale
        self.flow_matching_beta_alpha = flow_matching_beta_alpha
        self.flow_matching_beta_beta = flow_matching_beta_beta
        self.mask_action_dim_padding = mask_action_dim_padding
        self.enable_depth_reasoning = enable_depth_reasoning
        self.depth_mode = depth_mode
        self.num_depth_codes = num_depth_codes
        self.action_expert_depth_gate = action_expert_depth_gate
        self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
        self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
        self.action_output_token_id = action_output_token_id
        self.action_start_token_id = action_start_token_id
        self.action_end_token_id = action_end_token_id
        self.action_token_start_id = action_token_start_id
        self.num_action_tokens = num_action_tokens
        self.depth_output_token_id = depth_output_token_id
        self.depth_start_token_id = depth_start_token_id
        self.depth_end_token_id = depth_end_token_id
        self.depth_token_start_id = depth_token_start_id
        self.num_depth_tokens = num_depth_tokens
        self.state_start_token_id = state_start_token_id
        self.state_end_token_id = state_end_token_id
        self.state_token_start_id = state_token_start_id
        self.num_state_tokens = num_state_tokens
        self.add_setup_tokens = add_setup_tokens
        self.add_control_tokens = add_control_tokens
        self.norm_stats_filename = norm_stats_filename

    @staticmethod
    def _validate_release_action_config(
        *,
        action_expert_config: MolmoAct2ActionExpertConfig,
        action_expert_condition_source: str,
        action_expert_layer_mode: str,
        state_format: str,
    ) -> None:
        if action_expert_config.implementation != "new":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert.implementation='new'."
            )
        if action_expert_condition_source != "kv_cache":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert_condition_source='kv_cache'."
            )
        if action_expert_layer_mode != "per_layer":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert_layer_mode='per_layer'."
            )
        if state_format != "discrete":
            raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")

    @property
    def image_num_patch(self):
        assert self.vit_config is not None
        return self.vit_config.image_num_patch

    @property
    def num_attention_heads(self):
        return self.text_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.text_config.num_key_value_heads

    @property
    def head_dim(self):
        return self.text_config.head_dim

    @property
    def num_hidden_layers(self):
        return self.text_config.num_hidden_layers

    @property
    def hidden_size(self):
        return self.text_config.hidden_size

    @property
    def vocab_size(self):
        return self.text_config.vocab_size

    @property
    def max_position_embeddings(self):
        return self.text_config.max_position_embeddings


MolmoAct2VitConfig.register_for_auto_class()
MolmoAct2AdapterConfig.register_for_auto_class()
MolmoAct2TextConfig.register_for_auto_class()
MolmoAct2ActionExpertConfig.register_for_auto_class()
MolmoAct2Config.register_for_auto_class()
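
A small sanity check of the release constraints enforced by `_validate_release_action_config` above. This is a sketch that assumes the repository files are available locally (for example after `huggingface_hub.snapshot_download`); the unsupported value `"prefix"` is only an illustrative non-default:

```python
from configuration_molmoact2 import MolmoAct2Config

# Class defaults satisfy the release constraints (kv_cache conditioning,
# per-layer mode, discrete state), so default construction succeeds.
cfg = MolmoAct2Config()
print(cfg.action_expert_config.max_horizon, cfg.action_expert_config.max_action_dim)
# class defaults; the shipped config.json overrides these values

# Any unsupported combination is rejected at construction time.
try:
    MolmoAct2Config(action_expert_condition_source="prefix")
except ValueError as err:
    print(err)  # supports only action_expert_condition_source='kv_cache'
```
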
generation_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "bos_token_id": 151645,
  "eos_token_id": 151645,
  "pad_token_id": 151643,
  "transformers_version": "5.3.0"
}

image_processing_molmoact2.py
ADDED
@@ -0,0 +1,546 @@
| 1 |
+
"""Image processor class for MolmoAct2"""
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms
|
| 7 |
+
|
| 8 |
+
from transformers.image_utils import (
|
| 9 |
+
IMAGENET_STANDARD_MEAN,
|
| 10 |
+
IMAGENET_STANDARD_STD,
|
| 11 |
+
ImageInput,
|
| 12 |
+
PILImageResampling,
|
| 13 |
+
make_flat_list_of_images,
|
| 14 |
+
valid_images,
|
| 15 |
+
to_numpy_array,
|
| 16 |
+
)
|
| 17 |
+
from transformers.image_transforms import convert_to_rgb
|
| 18 |
+
from transformers.processing_utils import ImagesKwargs
|
| 19 |
+
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 22 |
+
from transformers.utils import TensorType, logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def normalize_image(
|
| 29 |
+
image: np.ndarray,
|
| 30 |
+
image_mean: list[float],
|
| 31 |
+
image_std: list[float],
|
| 32 |
+
) -> np.ndarray:
|
| 33 |
+
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
| 34 |
+
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
| 35 |
+
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
| 36 |
+
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
| 37 |
+
return image
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resize_image(
|
| 41 |
+
image: np.ndarray,
|
| 42 |
+
desired_output_size: list[int],
|
| 43 |
+
resample: PILImageResampling,
|
| 44 |
+
) -> np.ndarray:
|
| 45 |
+
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
| 46 |
+
dtype = image.dtype
|
| 47 |
+
if torch.is_floating_point(image):
|
| 48 |
+
in_min = 0.0
|
| 49 |
+
in_max = 1.0
|
| 50 |
+
resized = torchvision.transforms.Resize(
|
| 51 |
+
desired_output_size,
|
| 52 |
+
resample,
|
| 53 |
+
antialias=False,
|
| 54 |
+
)(image)
|
| 55 |
+
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
| 56 |
+
else:
|
| 57 |
+
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
|
| 58 |
+
in_min = 0.0
|
| 59 |
+
in_max = 255.0
|
| 60 |
+
resized = torchvision.transforms.Resize(
|
| 61 |
+
desired_output_size,
|
| 62 |
+
resample,
|
| 63 |
+
antialias=False,
|
| 64 |
+
)(image)
|
| 65 |
+
resized = torch.clip(resized, 0, 255).to(dtype)
|
| 66 |
+
|
| 67 |
+
resized = resized.to(torch.float32)
|
| 68 |
+
resized = (resized - in_min) / (in_max - in_min)
|
| 69 |
+
|
| 70 |
+
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
| 71 |
+
|
| 72 |
+
return resized
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def select_tiling(h, w, patch_size, max_num_crops):
|
| 76 |
+
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
| 77 |
+
original_size = np.stack([h, w]) # [1, 2]
|
| 78 |
+
original_res = h * w
|
| 79 |
+
tilings = []
|
| 80 |
+
for i in range(1, max_num_crops + 1):
|
| 81 |
+
for j in range(1, max_num_crops + 1):
|
| 82 |
+
if i*j <= max_num_crops:
|
| 83 |
+
tilings.append((i, j))
|
| 84 |
+
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
| 85 |
+
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
|
| 86 |
+
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
| 87 |
+
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
| 88 |
+
|
| 89 |
+
# How much we would need to scale the image to fit exactly in each tiling
|
| 90 |
+
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
| 91 |
+
|
| 92 |
+
# The original size can be zero in rare cases if the image is smaller than the margin
|
| 93 |
+
# In those cases letting the scale become infinite means the tiling is based on the
|
| 94 |
+
# other side, or falls back to the smallest tiling
|
| 95 |
+
with np.errstate(divide='ignore'):
|
| 96 |
+
required_scale_d = candidate_resolutions.astype(np.float32) / original_size,
|
| 97 |
+
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
| 98 |
+
if np.all(required_scale < 1):
|
| 99 |
+
# We are forced to downscale, so try to minimize the amount of downscaling
|
| 100 |
+
ix = np.argmax(required_scale)
|
| 101 |
+
else:
|
| 102 |
+
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
| 103 |
+
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
| 104 |
+
ix = np.argmin(required_scale)
|
| 105 |
+
return candidate_tilings[ix]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def build_resized_image(
|
| 109 |
+
image: np.ndarray,
|
| 110 |
+
base_image_input_size: list[int],
|
| 111 |
+
resample: PILImageResampling,
|
| 112 |
+
image_mean: list[float],
|
| 113 |
+
image_std: list[float],
|
| 114 |
+
image_patch_size: int,
|
| 115 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 116 |
+
resized = resize_image(
|
| 117 |
+
image, base_image_input_size, resample,
|
| 118 |
+
)
|
| 119 |
+
resized = normalize_image(resized, image_mean, image_std)
|
| 120 |
+
if len(resized.shape) == 3:
|
| 121 |
+
resized = np.expand_dims(resized, 0)
|
| 122 |
+
crop_patch_w = base_image_input_size[1] // image_patch_size
|
| 123 |
+
crop_patch_h = base_image_input_size[0] // image_patch_size
|
| 124 |
+
resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
| 125 |
+
return resized, resize_idx
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_overlapping_crops(
|
| 129 |
+
image: np.ndarray,
|
| 130 |
+
max_crops: int,
|
| 131 |
+
overlap_margins: list[int],
|
| 132 |
+
base_image_input_size: list[int],
|
| 133 |
+
resample: PILImageResampling,
|
| 134 |
+
image_mean: list[float],
|
| 135 |
+
image_std: list[float],
|
| 136 |
+
image_patch_size: int,
|
| 137 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 138 |
+
"""Decompose an image into a set of overlapping crops
|
| 139 |
+
|
| 140 |
+
:return crop_arr: [n_crops, h, w, 3] The crops
|
| 141 |
+
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized source image,
|
| 142 |
+
the index of the patch in `crop_arr` it corresponds to
|
| 143 |
+
"""
|
| 144 |
+
original_image_h, original_image_w = image.shape[:2]
|
| 145 |
+
crop_size = base_image_input_size[0]
|
| 146 |
+
assert base_image_input_size[0] == base_image_input_size[1]
|
| 147 |
+
|
| 148 |
+
left_margin, right_margin = overlap_margins
|
| 149 |
+
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
|
| 150 |
+
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
|
| 151 |
+
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
| 152 |
+
crop_window_size = crop_window_patches * image_patch_size
|
| 153 |
+
crop_patch_w = base_image_input_size[1] // image_patch_size
|
| 154 |
+
crop_patch_h = base_image_input_size[0] // image_patch_size
|
| 155 |
+
original_image_h, original_image_w = image.shape[:2]
|
| 156 |
+
crop_size = base_image_input_size[0]
|
| 157 |
+
|
| 158 |
+
# Decide how to tile the image, to account for the overlap margins we compute the tiling
|
| 159 |
+
# as if we had an image without the margins and were using a crop size without the margins
|
| 160 |
+
tiling = select_tiling(
|
| 161 |
+
original_image_h - total_margin_pixels,
|
| 162 |
+
original_image_w - total_margin_pixels,
|
| 163 |
+
crop_window_size,
|
| 164 |
+
max_crops,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
src = resize_image(
|
| 168 |
+
image,
|
| 169 |
+
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
|
| 170 |
+
resample,
|
| 171 |
+
)
|
| 172 |
+
src = normalize_image(src, image_mean, image_std)
|
| 173 |
+
|
| 174 |
+
# Now we have to split the image into crops, and track what patches came from
|
| 175 |
+
# where in `patch_idx_arr`
|
| 176 |
+
n_crops = tiling[0] * tiling[1]
|
| 177 |
+
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
|
| 178 |
+
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
|
| 179 |
+
on_crop = 0
|
| 180 |
+
for i in range(tiling[0]):
|
| 181 |
+
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crop_size`
|
| 182 |
+
# which results in overlapping crop windows
|
| 183 |
+
y0 = i*crop_window_size
|
| 184 |
+
for j in range(tiling[1]):
|
| 185 |
+
x0 = j*crop_window_size
|
| 186 |
+
crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
|
| 187 |
+
patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
|
| 188 |
+
patch_idx += on_crop * crop_patch_h * crop_patch_w
|
| 189 |
+
|
| 190 |
+
# Mask out idx that are in the overlap region
|
| 191 |
+
if i != 0:
|
| 192 |
+
patch_idx[:left_margin, :] = -1
|
| 193 |
+
if j != 0:
|
| 194 |
+
patch_idx[:, :left_margin] = -1
|
| 195 |
+
if i != tiling[0]-1:
|
| 196 |
+
patch_idx[-right_margin:, :] = -1
|
| 197 |
+
if j != tiling[1]-1:
|
| 198 |
+
patch_idx[:, -right_margin:] = -1
|
| 199 |
+
patch_idx_arr[on_crop] = patch_idx
|
| 200 |
+
on_crop += 1
|
| 201 |
+
|
| 202 |
+
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
|
| 203 |
+
# so it is in left-to-right order
|
| 204 |
+
patch_idx_arr = np.reshape(
|
| 205 |
+
patch_idx_arr,
|
| 206 |
+
[tiling[0], tiling[1], crop_patch_h, crop_patch_w]
|
| 207 |
+
)
|
| 208 |
+
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
|
| 209 |
+
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
|
| 210 |
+
|
| 211 |
+
# Now get the parts not in the overlap region, so it should map each patch in `src`
|
| 212 |
+
# to the correct patch it should come from in `crop_arr`
|
| 213 |
+
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
|
| 214 |
+
src.shape[0]//image_patch_size,
|
| 215 |
+
src.shape[1]//image_patch_size,
|
| 216 |
+
)
|
| 217 |
+
return crop_arr, patch_idx_arr
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
| 221 |
+
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
| 222 |
+
if len(array.shape) == 3:
|
| 223 |
+
n_crops, h, w = array.shape
|
| 224 |
+
h_patches = h//patch_size
|
| 225 |
+
w_patches = w//patch_size
|
| 226 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
| 227 |
+
array = np.transpose(array, [0, 1, 3, 2, 4])
|
| 228 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
|
| 229 |
+
return array
|
| 230 |
+
else:
|
| 231 |
+
n_crops, h, w, c = array.shape
|
| 232 |
+
h_patches = h//patch_size
|
| 233 |
+
w_patches = w//patch_size
|
| 234 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
| 235 |
+
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
| 236 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
|
| 237 |
+
return array
|
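A quick shape check for `batch_pixels_to_patches` (illustrative only, added by the editor; it uses the 378-pixel crop size and 14-pixel patches configured further below):

```python
# Illustrative shape check, not part of the uploaded file: 7 crops of 378x378 RGB
# pixels become 7 x (27*27) patches, each flattened to 14*14*3 pixel values.
import numpy as np

crops = np.zeros([7, 378, 378, 3], dtype=np.float32)
patches = batch_pixels_to_patches(crops, patch_size=14)
print(patches.shape)  # (7, 729, 588)
```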
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def arange_for_pooling(
|
| 241 |
+
idx_arr: np.ndarray,
|
| 242 |
+
pool_h: int,
|
| 243 |
+
pool_w: int,
|
| 244 |
+
) -> np.ndarray:
|
| 245 |
+
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
| 246 |
+
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
| 247 |
+
idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
|
| 248 |
+
mode='constant',constant_values=-1)
|
| 249 |
+
return einops.rearrange(
|
| 250 |
+
idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def image_to_patches_and_grids(
|
| 254 |
+
image: np.ndarray,
|
| 255 |
+
max_crops: int,
|
| 256 |
+
overlap_margins: list[int],
|
| 257 |
+
base_image_input_size: list[int],
|
| 258 |
+
resample: PILImageResampling,
|
| 259 |
+
image_mean: list[float],
|
| 260 |
+
image_std: list[float],
|
| 261 |
+
image_patch_size: int,
|
| 262 |
+
image_pooling_w: int,
|
| 263 |
+
image_pooling_h: int,
|
| 264 |
+
crop_mode: str = "overlap-and-resize-c2",
|
| 265 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 266 |
+
"""
|
| 267 |
+
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
| 268 |
+
:return crops, the image crops to process with the ViT
|
| 269 |
+
:return pooled_patch_idx, for each patch_id token in `image_tokens`, the indices of the
|
| 270 |
+
patches in `crops` to pool for that token, masked with -1
|
| 271 |
+
"""
|
| 272 |
+
if isinstance(base_image_input_size, int):
|
| 273 |
+
base_image_input_size = (base_image_input_size, base_image_input_size)
|
| 274 |
+
|
| 275 |
+
base_image_input_d = image_patch_size
|
| 276 |
+
pooling_w = image_pooling_w
|
| 277 |
+
pooling_h = image_pooling_h
|
| 278 |
+
crop_patch_w = base_image_input_size[1] // base_image_input_d
|
| 279 |
+
crop_patch_h = base_image_input_size[0] // base_image_input_d
|
| 280 |
+
|
| 281 |
+
if crop_mode == "resize":
|
| 282 |
+
resized, resize_idx = build_resized_image(
|
| 283 |
+
image,
|
| 284 |
+
base_image_input_size,
|
| 285 |
+
resample,
|
| 286 |
+
image_mean,
|
| 287 |
+
image_std,
|
| 288 |
+
image_patch_size,
|
| 289 |
+
)
|
| 290 |
+
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
| 291 |
+
resized_h, resized_w = resize_idx.shape[:2]
|
| 292 |
+
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
| 293 |
+
image_grid = [np.array([resized_h, resized_w, 0, 0])]
|
| 294 |
+
return (
|
| 295 |
+
np.stack(image_grid, 0),
|
| 296 |
+
batch_pixels_to_patches(resized, image_patch_size),
|
| 297 |
+
resize_idx,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
|
| 301 |
+
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
|
| 302 |
+
|
| 303 |
+
crop_arr, patch_idx_arr = build_overlapping_crops(
|
| 304 |
+
image,
|
| 305 |
+
max_crops,
|
| 306 |
+
overlap_margins,
|
| 307 |
+
base_image_input_size,
|
| 308 |
+
resample,
|
| 309 |
+
image_mean,
|
| 310 |
+
image_std,
|
| 311 |
+
image_patch_size,
|
| 312 |
+
)
|
| 313 |
+
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
| 314 |
+
h, w = pooling_idx.shape[:2]
|
| 315 |
+
pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
|
| 316 |
+
|
| 317 |
+
# Finally do the same for the global image
|
| 318 |
+
resized, resize_idx = build_resized_image(
|
| 319 |
+
image,
|
| 320 |
+
base_image_input_size,
|
| 321 |
+
resample,
|
| 322 |
+
image_mean,
|
| 323 |
+
image_std,
|
| 324 |
+
image_patch_size,
|
| 325 |
+
)
|
| 326 |
+
crop_arr = np.concatenate([resized, crop_arr], 0)
|
| 327 |
+
|
| 328 |
+
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
| 329 |
+
resized_h, resized_w = resize_idx.shape[:2]
|
| 330 |
+
resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])
|
| 331 |
+
|
| 332 |
+
# The global (resized) image goes first in `crop_arr`, so the patch indices of the high-res crops are shifted up by its patch count
|
| 333 |
+
pooling_idx = np.where(
|
| 334 |
+
pooling_idx >= 0,
|
| 335 |
+
pooling_idx + crop_patch_h*crop_patch_w,
|
| 336 |
+
-1
|
| 337 |
+
)
|
| 338 |
+
pooling_idx = np.concatenate([resize_idx, pooling_idx])
|
| 339 |
+
image_grid = [np.array([resized_h, resized_w, h, w])]
|
| 340 |
+
|
| 341 |
+
return (
|
| 342 |
+
np.stack(image_grid, 0),
|
| 343 |
+
batch_pixels_to_patches(crop_arr, image_patch_size),
|
| 344 |
+
pooling_idx
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
|
| 349 |
+
max_crops: Optional[int]
|
| 350 |
+
overlap_margins: Optional[list[int]]
|
| 351 |
+
crop_mode: Optional[str]
|
| 352 |
+
patch_size: Optional[int]
|
| 353 |
+
pooling_size: Optional[list[int]]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
| 357 |
+
r"""
|
| 358 |
+
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
size (`dict[str, int]`, *optional*, defaults to `{"height": 378, "width": 378}`):
|
| 362 |
+
Size of the image after resizing.
|
| 363 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 364 |
+
Resampling filter to use when resizing the image.
|
| 365 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 366 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 367 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 368 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 369 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 370 |
+
Whether to convert the image to RGB.
|
| 371 |
+
max_crops (`int`, *optional*, defaults to `8`):
|
| 372 |
+
Maximum number of crops to use per image.
|
| 373 |
+
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
| 374 |
+
Overlap margins to use.
|
| 375 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 376 |
+
The spatial patch size of the vision encoder.
|
| 377 |
+
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
| 378 |
+
The pooling size of the vision adapter.
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
| 382 |
+
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
size: Optional[dict[str, int]] = None,
|
| 386 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 387 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 388 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 389 |
+
do_convert_rgb: bool = True,
|
| 390 |
+
max_crops: int = 8,
|
| 391 |
+
overlap_margins: list[int] = [4, 4],
|
| 392 |
+
crop_mode: str = "overlap-and-resize-c2",
|
| 393 |
+
patch_size: int = 14,
|
| 394 |
+
pooling_size: list[int] = [2, 2],
|
| 395 |
+
**kwargs,
|
| 396 |
+
) -> None:
|
| 397 |
+
super().__init__(**kwargs)
|
| 398 |
+
size = size if size is not None else {"height": 378, "width": 378}
|
| 399 |
+
size = get_size_dict(size, default_to_square=True)
|
| 400 |
+
self.size = size
|
| 401 |
+
|
| 402 |
+
self.resample = resample
|
| 403 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 404 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 405 |
+
self.do_convert_rgb = do_convert_rgb
|
| 406 |
+
|
| 407 |
+
self.max_crops = max_crops
|
| 408 |
+
self.overlap_margins = overlap_margins
|
| 409 |
+
self.crop_mode = crop_mode
|
| 410 |
+
self.patch_size = patch_size
|
| 411 |
+
self.pooling_size = pooling_size
|
| 412 |
+
|
| 413 |
+
def preprocess(
|
| 414 |
+
self,
|
| 415 |
+
images: ImageInput,
|
| 416 |
+
size: Optional[dict[str, int]] = None,
|
| 417 |
+
resample: Optional[PILImageResampling] = None,
|
| 418 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 419 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 420 |
+
do_convert_rgb: Optional[bool] = None,
|
| 421 |
+
max_crops: Optional[int] = None,
|
| 422 |
+
overlap_margins: Optional[list[int]] = None,
|
| 423 |
+
crop_mode: Optional[str] = None,
|
| 424 |
+
patch_size: Optional[int] = None,
|
| 425 |
+
pooling_size: Optional[list[int]] = None,
|
| 426 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 427 |
+
**kwargs,
|
| 428 |
+
) -> BatchFeature:
|
| 429 |
+
"""
|
| 430 |
+
Args:
|
| 431 |
+
images (`ImageInput`):
|
| 432 |
+
Image to preprocess.
|
| 433 |
+
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
| 434 |
+
Size of the image after resizing.
|
| 435 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 436 |
+
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 437 |
+
has an effect if `do_resize` is set to `True`.
|
| 438 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
| 439 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 440 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
| 441 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 442 |
+
`True`.
|
| 443 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 444 |
+
Whether to convert the image to RGB.
|
| 445 |
+
max_crops (`int`, *optional*, defaults to `self.max_crops`):
|
| 446 |
+
Maximum number of crops to use per image.
|
| 447 |
+
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
|
| 448 |
+
Overlap margins to use.
|
| 449 |
+
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
| 450 |
+
The spatial patch size of the vision encoder.
|
| 451 |
+
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
| 452 |
+
The pooling size of the vision adapter.
|
| 453 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 454 |
+
The type of tensors to return. Can be one of:
|
| 455 |
+
- Unset: Return a list of `np.ndarray`.
|
| 456 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 457 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 458 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 459 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
A `BatchFeature` containing the following keys:
|
| 463 |
+
- `pixel_values`: The preprocessed images.
|
| 464 |
+
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
|
| 465 |
+
- `image_grids`: The image grids.
|
| 466 |
+
- `image_num_crops`: The number of crops for each image.
|
| 467 |
+
"""
|
| 468 |
+
if size is not None:
|
| 469 |
+
if "height" not in size or "width" not in size:
|
| 470 |
+
raise ValueError("size must contain 'height' and 'width' keys.")
|
| 471 |
+
else:
|
| 472 |
+
size = {**self.size}
|
| 473 |
+
|
| 474 |
+
base_image_input_size = [size["height"], size["width"]]
|
| 475 |
+
|
| 476 |
+
resample = resample or self.resample
|
| 477 |
+
image_mean = image_mean or self.image_mean
|
| 478 |
+
image_std = image_std or self.image_std
|
| 479 |
+
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
| 480 |
+
|
| 481 |
+
max_crops = max_crops or self.max_crops
|
| 482 |
+
overlap_margins = overlap_margins or self.overlap_margins
|
| 483 |
+
crop_mode = crop_mode or self.crop_mode
|
| 484 |
+
patch_size = patch_size or self.patch_size
|
| 485 |
+
pooling_size = pooling_size or self.pooling_size
|
| 486 |
+
|
| 487 |
+
image_pooling_h, image_pooling_w = pooling_size
|
| 488 |
+
|
| 489 |
+
if images is not None:
|
| 490 |
+
images = self.fetch_images(images)
|
| 491 |
+
images = make_flat_list_of_images(images)
|
| 492 |
+
|
| 493 |
+
if images is not None and not valid_images(images):
|
| 494 |
+
raise ValueError(
|
| 495 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 496 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if do_convert_rgb:
|
| 500 |
+
images = [convert_to_rgb(image) for image in images]
|
| 501 |
+
|
| 502 |
+
# All transformations expect numpy arrays.
|
| 503 |
+
images = [to_numpy_array(image) for image in images]
|
| 504 |
+
|
| 505 |
+
data = {}
|
| 506 |
+
if images is not None:
|
| 507 |
+
batch_grids = []
|
| 508 |
+
batch_crops = []
|
| 509 |
+
batch_pooled_patches_idx = []
|
| 510 |
+
batch_num_crops = []
|
| 511 |
+
|
| 512 |
+
for image in images:
|
| 513 |
+
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
| 514 |
+
image,
|
| 515 |
+
max_crops,
|
| 516 |
+
overlap_margins,
|
| 517 |
+
base_image_input_size,
|
| 518 |
+
resample,
|
| 519 |
+
image_mean,
|
| 520 |
+
image_std,
|
| 521 |
+
patch_size,
|
| 522 |
+
image_pooling_w,
|
| 523 |
+
image_pooling_h,
|
| 524 |
+
crop_mode,
|
| 525 |
+
)
|
| 526 |
+
batch_grids.append(image_grid)
|
| 527 |
+
batch_crops.append(crops)
|
| 528 |
+
batch_pooled_patches_idx.append(pooled_idx)
|
| 529 |
+
batch_num_crops.append(crops.shape[0])
|
| 530 |
+
|
| 531 |
+
pixel_values = np.concatenate(batch_crops, 0)
|
| 532 |
+
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
| 533 |
+
image_grids = np.concatenate(batch_grids, 0)
|
| 534 |
+
image_num_crops = np.array(batch_num_crops)
|
| 535 |
+
|
| 536 |
+
data.update(
|
| 537 |
+
pixel_values=pixel_values,
|
| 538 |
+
image_token_pooling=image_token_pooling,
|
| 539 |
+
image_grids=image_grids,
|
| 540 |
+
image_num_crops=image_num_crops,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
return BatchFeature(data, tensor_type=return_tensors)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
MolmoAct2ImageProcessor.register_for_auto_class()
|
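A minimal usage sketch of the processor defined above. This is an editor-added illustration: it assumes the file is loaded from the checkpoint with `trust_remote_code=True`, the repository id is a placeholder, and the sample image is one of the assets shipped in this repo.

```python
# Minimal sketch (assumptions: the repo id below is a placeholder, and trust_remote_code
# makes AutoImageProcessor resolve to MolmoAct2ImageProcessor from this file).
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("allenai/MolmoAct2-LIBERO", trust_remote_code=True)
image = Image.open("assets/sample_agentview_rgb.png")

features = processor(images=image, return_tensors="pt")
# The four model inputs produced by preprocess():
print(features["pixel_values"].shape)        # [total_crops, patches_per_crop, pixels_per_patch]
print(features["image_token_pooling"].shape) # [n_pooled_tokens, pooling_h * pooling_w]
print(features["image_grids"].shape)         # [n_images, 4]
print(features["image_num_crops"])           # crops per image
```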
inference.py
ADDED
|
@@ -0,0 +1,768 @@
|
| 1 |
+
"""Inference utilities for MolmoAct2"""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Iterable, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from transformers.cache_utils import Cache
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class _ActionFlowInputs:
|
| 14 |
+
trajectory: torch.Tensor
|
| 15 |
+
context: Any
|
| 16 |
+
modulations: Sequence[Any]
|
| 17 |
+
action_dim_is_pad: Optional[torch.Tensor]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class _ActionFlowCudaGraph:
|
| 22 |
+
key: Tuple[Any, ...]
|
| 23 |
+
graph: torch.cuda.CUDAGraph
|
| 24 |
+
static_inputs: _ActionFlowInputs
|
| 25 |
+
output: torch.Tensor
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class _DepthDecodeCudaGraphLayerStage:
|
| 30 |
+
residual: torch.Tensor
|
| 31 |
+
query: torch.Tensor
|
| 32 |
+
key: torch.Tensor
|
| 33 |
+
value: torch.Tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class _DepthDecodeCudaGraphPostStage:
|
| 38 |
+
graph: torch.cuda.CUDAGraph
|
| 39 |
+
attn_context: torch.Tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class _DepthDecodeCudaGraph:
|
| 44 |
+
cache_key: Tuple[Any, ...]
|
| 45 |
+
pre_graph: torch.cuda.CUDAGraph
|
| 46 |
+
token_ids: torch.Tensor
|
| 47 |
+
cos: torch.Tensor
|
| 48 |
+
sin: torch.Tensor
|
| 49 |
+
positions: torch.Tensor
|
| 50 |
+
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
|
| 51 |
+
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
|
| 52 |
+
output: torch.Tensor
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class _DepthDecodeCudaGraphSpec:
|
| 57 |
+
eligible: bool
|
| 58 |
+
cache_key_prefix: Tuple[Any, ...]
|
| 59 |
+
num_hidden_layers: int
|
| 60 |
+
head_dim: int
|
| 61 |
+
num_attention_heads: int
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
|
| 65 |
+
if past_key_values is None:
|
| 66 |
+
return 0
|
| 67 |
+
seq_len = past_key_values.get_seq_length()
|
| 68 |
+
if torch.is_tensor(seq_len):
|
| 69 |
+
return int(seq_len.item())
|
| 70 |
+
return int(seq_len)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
|
| 74 |
+
if past_key_values is None:
|
| 75 |
+
return -1
|
| 76 |
+
max_len = past_key_values.get_max_cache_shape()
|
| 77 |
+
if torch.is_tensor(max_len):
|
| 78 |
+
return int(max_len.item())
|
| 79 |
+
return int(max_len)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _iter_cache_key_values(
|
| 83 |
+
past_key_values: Cache,
|
| 84 |
+
) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
|
| 85 |
+
layers = getattr(past_key_values, "layers", None)
|
| 86 |
+
if layers is not None:
|
| 87 |
+
for layer in layers:
|
| 88 |
+
yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
| 89 |
+
return
|
| 90 |
+
for layer in past_key_values:
|
| 91 |
+
yield layer[0], layer[1]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class _DepthDecodeStaticLayerCache:
|
| 95 |
+
is_compileable = False
|
| 96 |
+
is_sliding = False
|
| 97 |
+
|
| 98 |
+
def __init__(self, max_cache_len: int) -> None:
|
| 99 |
+
self.max_cache_len = int(max_cache_len)
|
| 100 |
+
self.cumulative_length = 0
|
| 101 |
+
self.keys: Optional[torch.Tensor] = None
|
| 102 |
+
self.values: Optional[torch.Tensor] = None
|
| 103 |
+
|
| 104 |
+
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
| 105 |
+
bsz, n_heads = key_states.shape[:2]
|
| 106 |
+
self.keys = torch.empty(
|
| 107 |
+
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
| 108 |
+
dtype=key_states.dtype,
|
| 109 |
+
device=key_states.device,
|
| 110 |
+
)
|
| 111 |
+
self.values = torch.empty(
|
| 112 |
+
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
| 113 |
+
dtype=value_states.dtype,
|
| 114 |
+
device=value_states.device,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def update(
|
| 118 |
+
self,
|
| 119 |
+
key_states: torch.Tensor,
|
| 120 |
+
value_states: torch.Tensor,
|
| 121 |
+
*args,
|
| 122 |
+
**kwargs,
|
| 123 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 124 |
+
if self.keys is None:
|
| 125 |
+
self._allocate(key_states, value_states)
|
| 126 |
+
start = self.cumulative_length
|
| 127 |
+
end = start + key_states.shape[-2]
|
| 128 |
+
if end > self.max_cache_len:
|
| 129 |
+
raise RuntimeError(
|
| 130 |
+
f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
|
| 131 |
+
)
|
| 132 |
+
self.keys[:, :, start:end, :].copy_(key_states)
|
| 133 |
+
self.values[:, :, start:end, :].copy_(value_states)
|
| 134 |
+
self.cumulative_length = end
|
| 135 |
+
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
| 136 |
+
|
| 137 |
+
def get_seq_length(self) -> int:
|
| 138 |
+
return self.cumulative_length
|
| 139 |
+
|
| 140 |
+
def get_max_cache_shape(self) -> int:
|
| 141 |
+
return -1
|
| 142 |
+
|
| 143 |
+
def reset(self) -> None:
|
| 144 |
+
self.cumulative_length = 0
|
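To make the in-place behaviour concrete, here is a toy interaction with the layer cache (an editor-added illustration; the shapes are arbitrary):

```python
# Illustrative only: the layer cache allocates [bsz, n_heads, max_cache_len, head_dim]
# buffers on the first update and then copies new key/value states in place, so the
# returned slices are views with stable storage, which suits CUDA-graph replay.
import torch

layer = _DepthDecodeStaticLayerCache(max_cache_len=16)
k = torch.zeros(1, 8, 1, 64)
v = torch.zeros(1, 8, 1, 64)
keys, values = layer.update(k, v)
print(keys.shape, layer.get_seq_length())  # torch.Size([1, 8, 1, 64]) 1
```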
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class _DepthDecodeStaticCache(Cache):
|
| 148 |
+
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
|
| 149 |
+
text_config = config.get_text_config(decoder=True)
|
| 150 |
+
super().__init__(
|
| 151 |
+
layers=[
|
| 152 |
+
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
|
| 153 |
+
for _ in range(text_config.num_hidden_layers)
|
| 154 |
+
]
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
def get_seq_length(self, layer_idx: int = 0) -> int:
|
| 158 |
+
return self.layers[layer_idx].get_seq_length()
|
| 159 |
+
|
| 160 |
+
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
| 161 |
+
return self.layers[layer_idx].get_max_cache_shape()
|
| 162 |
+
|
| 163 |
+
def reset(self) -> None:
|
| 164 |
+
for layer in self.layers:
|
| 165 |
+
layer.reset()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ActionCudaGraphManager:
|
| 169 |
+
def __init__(self, model: Any) -> None:
|
| 170 |
+
self.model = model
|
| 171 |
+
self.enabled = True
|
| 172 |
+
self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None
|
| 173 |
+
|
| 174 |
+
def set_enabled(self, enabled: bool) -> None:
|
| 175 |
+
self.enabled = bool(enabled)
|
| 176 |
+
|
| 177 |
+
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
|
| 178 |
+
action_model = self.model
|
| 179 |
+
if not self.enabled:
|
| 180 |
+
return False
|
| 181 |
+
if action_model.training or action_model._require_action_expert().training:
|
| 182 |
+
return False
|
| 183 |
+
if inputs.trajectory.device.type != "cuda":
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
def all_on_cuda():
|
| 187 |
+
yield inputs.trajectory
|
| 188 |
+
for k, v in inputs.context.kv_contexts:
|
| 189 |
+
yield k
|
| 190 |
+
yield v
|
| 191 |
+
for t in (
|
| 192 |
+
inputs.context.cross_mask,
|
| 193 |
+
inputs.context.self_mask,
|
| 194 |
+
inputs.context.valid_action,
|
| 195 |
+
inputs.action_dim_is_pad,
|
| 196 |
+
):
|
| 197 |
+
if t is not None:
|
| 198 |
+
yield t
|
| 199 |
+
if inputs.context.rope_cache is not None:
|
| 200 |
+
yield from inputs.context.rope_cache
|
| 201 |
+
for step in inputs.modulations:
|
| 202 |
+
yield step.conditioning
|
| 203 |
+
for block_modulation in step.block_modulations:
|
| 204 |
+
yield from block_modulation
|
| 205 |
+
yield from step.final_modulation
|
| 206 |
+
|
| 207 |
+
return all(t.device.type == "cuda" for t in all_on_cuda())
|
| 208 |
+
|
| 209 |
+
def run_action_flow(
|
| 210 |
+
self,
|
| 211 |
+
inputs: _ActionFlowInputs,
|
| 212 |
+
steps: int,
|
| 213 |
+
run_loop,
|
| 214 |
+
) -> torch.Tensor:
|
| 215 |
+
key = _cuda_graph_key(inputs, steps)
|
| 216 |
+
cache = self.action_flow_graph
|
| 217 |
+
if cache is None or cache.key != key:
|
| 218 |
+
static_inputs = _clone_static_inputs(inputs)
|
| 219 |
+
graph, output = _capture_cuda_graph(
|
| 220 |
+
lambda: run_loop(static_inputs, steps),
|
| 221 |
+
inputs.trajectory.device,
|
| 222 |
+
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
|
| 223 |
+
)
|
| 224 |
+
cache = _ActionFlowCudaGraph(
|
| 225 |
+
key=key,
|
| 226 |
+
graph=graph,
|
| 227 |
+
static_inputs=static_inputs,
|
| 228 |
+
output=output,
|
| 229 |
+
)
|
| 230 |
+
self.action_flow_graph = cache
|
| 231 |
+
else:
|
| 232 |
+
_copy_inputs_(cache.static_inputs, inputs)
|
| 233 |
+
|
| 234 |
+
cache.graph.replay()
|
| 235 |
+
return cache.output.clone()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class DepthDecodeCudaGraphManager:
|
| 239 |
+
def __init__(self, model: Any) -> None:
|
| 240 |
+
self.model = model
|
| 241 |
+
self.backbone = model.model
|
| 242 |
+
self.enabled = True
|
| 243 |
+
self.graph: Optional[_DepthDecodeCudaGraph] = None
|
| 244 |
+
self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
|
| 245 |
+
|
| 246 |
+
def set_enabled(self, enabled: bool) -> None:
|
| 247 |
+
self.enabled = bool(enabled)
|
| 248 |
+
|
| 249 |
+
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
| 250 |
+
return _DepthDecodeStaticCache(
|
| 251 |
+
config=self.model.config.text_config,
|
| 252 |
+
max_cache_len=max_cache_len,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
| 256 |
+
static = self.graph_spec
|
| 257 |
+
if static is None:
|
| 258 |
+
cfg = self.backbone.transformer.config
|
| 259 |
+
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
| 260 |
+
static = _DepthDecodeCudaGraphSpec(
|
| 261 |
+
eligible=(
|
| 262 |
+
not cfg.norm_after
|
| 263 |
+
and cfg.rope_scaling_layers is None
|
| 264 |
+
and getattr(rotary_emb, "rope_type", None) == "default"
|
| 265 |
+
and cfg._attn_implementation == "sdpa"
|
| 266 |
+
),
|
| 267 |
+
cache_key_prefix=(
|
| 268 |
+
cfg.hidden_size,
|
| 269 |
+
cfg.num_attention_heads,
|
| 270 |
+
cfg.num_key_value_heads,
|
| 271 |
+
cfg.head_dim,
|
| 272 |
+
cfg.num_hidden_layers,
|
| 273 |
+
cfg.use_qk_norm,
|
| 274 |
+
cfg.qk_norm_type,
|
| 275 |
+
cfg._attn_implementation,
|
| 276 |
+
),
|
| 277 |
+
num_hidden_layers=cfg.num_hidden_layers,
|
| 278 |
+
head_dim=cfg.head_dim,
|
| 279 |
+
num_attention_heads=cfg.num_attention_heads,
|
| 280 |
+
)
|
| 281 |
+
self.graph_spec = static
|
| 282 |
+
return static
|
| 283 |
+
|
| 284 |
+
def can_use(
|
| 285 |
+
self,
|
| 286 |
+
next_input_ids: torch.Tensor,
|
| 287 |
+
*,
|
| 288 |
+
past_key_values: Cache,
|
| 289 |
+
attention_bias: torch.Tensor,
|
| 290 |
+
) -> bool:
|
| 291 |
+
if (
|
| 292 |
+
not self.enabled
|
| 293 |
+
or self.model.training
|
| 294 |
+
or self.backbone.transformer.training
|
| 295 |
+
):
|
| 296 |
+
return False
|
| 297 |
+
if next_input_ids.device.type != "cuda":
|
| 298 |
+
return False
|
| 299 |
+
if (
|
| 300 |
+
next_input_ids.ndim != 2
|
| 301 |
+
or next_input_ids.shape[0] != 1
|
| 302 |
+
or next_input_ids.shape[1] != 1
|
| 303 |
+
):
|
| 304 |
+
return False
|
| 305 |
+
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
| 306 |
+
return False
|
| 307 |
+
if (
|
| 308 |
+
not torch.is_tensor(attention_bias)
|
| 309 |
+
or attention_bias.device != next_input_ids.device
|
| 310 |
+
):
|
| 311 |
+
return False
|
| 312 |
+
return self._depth_decode_spec().eligible
|
| 313 |
+
|
| 314 |
+
def _depth_decode_key(
|
| 315 |
+
self,
|
| 316 |
+
next_input_ids: torch.Tensor,
|
| 317 |
+
attention_bias: torch.Tensor,
|
| 318 |
+
) -> Tuple[Any, ...]:
|
| 319 |
+
device = next_input_ids.device
|
| 320 |
+
return (
|
| 321 |
+
self._depth_decode_spec().cache_key_prefix,
|
| 322 |
+
device.type,
|
| 323 |
+
device.index,
|
| 324 |
+
self.model.lm_head.weight.dtype,
|
| 325 |
+
attention_bias.shape[-1],
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
def _select_depth_decode_rope(
|
| 329 |
+
self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
|
| 330 |
+
) -> None:
|
| 331 |
+
emb = self.backbone.transformer.rotary_emb
|
| 332 |
+
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
| 333 |
+
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
| 334 |
+
|
| 335 |
+
def _depth_decode_pre_layer(
|
| 336 |
+
self,
|
| 337 |
+
layer_idx: int,
|
| 338 |
+
hidden_states: torch.Tensor,
|
| 339 |
+
cos: torch.Tensor,
|
| 340 |
+
sin: torch.Tensor,
|
| 341 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 342 |
+
block = self.backbone.transformer.blocks[layer_idx]
|
| 343 |
+
attention = block.self_attn
|
| 344 |
+
residual = hidden_states
|
| 345 |
+
hidden_states = block.attn_norm(hidden_states)
|
| 346 |
+
|
| 347 |
+
input_shape = hidden_states.shape[:-1]
|
| 348 |
+
hidden_shape = (*input_shape, -1, attention.head_dim)
|
| 349 |
+
qkv = attention.att_proj(hidden_states)
|
| 350 |
+
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
| 351 |
+
value_states = value_states.view(hidden_shape)
|
| 352 |
+
|
| 353 |
+
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
| 354 |
+
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
| 355 |
+
|
| 356 |
+
if apply_qk_norm and not norm_after_view:
|
| 357 |
+
query_states = attention.q_norm(query_states)
|
| 358 |
+
key_states = attention.k_norm(key_states)
|
| 359 |
+
|
| 360 |
+
query_states = query_states.view(hidden_shape)
|
| 361 |
+
key_states = key_states.view(hidden_shape)
|
| 362 |
+
|
| 363 |
+
if norm_after_view:
|
| 364 |
+
query_states = attention.q_norm(query_states)
|
| 365 |
+
key_states = attention.k_norm(key_states)
|
| 366 |
+
|
| 367 |
+
query_states = query_states.transpose(1, 2)
|
| 368 |
+
key_states = key_states.transpose(1, 2)
|
| 369 |
+
value_states = value_states.transpose(1, 2)
|
| 370 |
+
query_states, key_states = _apply_rotary_pos_emb(
|
| 371 |
+
query_states, key_states, cos, sin
|
| 372 |
+
)
|
| 373 |
+
return residual, query_states, key_states, value_states
|
| 374 |
+
|
| 375 |
+
def _depth_decode_pre0(
|
| 376 |
+
self,
|
| 377 |
+
token_ids: torch.Tensor,
|
| 378 |
+
cos: torch.Tensor,
|
| 379 |
+
sin: torch.Tensor,
|
| 380 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 381 |
+
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
| 382 |
+
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
| 383 |
+
|
| 384 |
+
def _depth_decode_post_layer(
|
| 385 |
+
self,
|
| 386 |
+
layer_idx: int,
|
| 387 |
+
residual: torch.Tensor,
|
| 388 |
+
attn_context: torch.Tensor,
|
| 389 |
+
) -> torch.Tensor:
|
| 390 |
+
block = self.backbone.transformer.blocks[layer_idx]
|
| 391 |
+
attention = block.self_attn
|
| 392 |
+
input_shape = residual.shape[:-1]
|
| 393 |
+
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
| 394 |
+
attn_output = attention.attn_out(attn_output)
|
| 395 |
+
hidden_states = residual + block.dropout(attn_output)
|
| 396 |
+
|
| 397 |
+
residual = hidden_states
|
| 398 |
+
hidden_states = block.ff_norm(hidden_states)
|
| 399 |
+
hidden_states = block.mlp(hidden_states)
|
| 400 |
+
hidden_states = residual + block.dropout(hidden_states)
|
| 401 |
+
return hidden_states
|
| 402 |
+
|
| 403 |
+
def _depth_decode_post_and_pre_next(
|
| 404 |
+
self,
|
| 405 |
+
layer_idx: int,
|
| 406 |
+
residual: torch.Tensor,
|
| 407 |
+
attn_context: torch.Tensor,
|
| 408 |
+
cos: torch.Tensor,
|
| 409 |
+
sin: torch.Tensor,
|
| 410 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 411 |
+
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
| 412 |
+
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
| 413 |
+
|
| 414 |
+
def _depth_decode_last_post(
|
| 415 |
+
self,
|
| 416 |
+
layer_idx: int,
|
| 417 |
+
residual: torch.Tensor,
|
| 418 |
+
attn_context: torch.Tensor,
|
| 419 |
+
) -> torch.Tensor:
|
| 420 |
+
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
| 421 |
+
return self.backbone.transformer.ln_f(hidden_states)
|
| 422 |
+
|
| 423 |
+
def _build_depth_decode_graph(
|
| 424 |
+
self,
|
| 425 |
+
next_input_ids: torch.Tensor,
|
| 426 |
+
*,
|
| 427 |
+
past_length: int,
|
| 428 |
+
attention_bias: torch.Tensor,
|
| 429 |
+
) -> _DepthDecodeCudaGraph:
|
| 430 |
+
text_config = self.backbone.transformer.config
|
| 431 |
+
device = next_input_ids.device
|
| 432 |
+
dtype = self.model.lm_head.weight.dtype
|
| 433 |
+
static = self._depth_decode_spec()
|
| 434 |
+
num_layers = static.num_hidden_layers
|
| 435 |
+
head_dim = static.head_dim
|
| 436 |
+
max_cache_len = int(attention_bias.shape[-1])
|
| 437 |
+
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
| 438 |
+
self.backbone.transformer.prepare_rope_cache(
|
| 439 |
+
device=device, max_seq_len=max_rope_len
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
| 443 |
+
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
| 444 |
+
sin = torch.empty_like(cos)
|
| 445 |
+
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
| 446 |
+
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
| 447 |
+
|
| 448 |
+
token_ids.copy_(next_input_ids)
|
| 449 |
+
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
| 450 |
+
|
| 451 |
+
pre_graph, pre_output = _capture_cuda_graph(
|
| 452 |
+
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
| 453 |
+
device,
|
| 454 |
+
)
|
| 455 |
+
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
| 456 |
+
post_graphs = []
|
| 457 |
+
for layer_idx in range(num_layers - 1):
|
| 458 |
+
stage = stages[-1]
|
| 459 |
+
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
| 460 |
+
graph, output = _capture_cuda_graph(
|
| 461 |
+
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
| 462 |
+
self._depth_decode_post_and_pre_next(
|
| 463 |
+
layer_idx,
|
| 464 |
+
stage.residual,
|
| 465 |
+
attn_context,
|
| 466 |
+
cos,
|
| 467 |
+
sin,
|
| 468 |
+
)
|
| 469 |
+
),
|
| 470 |
+
device,
|
| 471 |
+
)
|
| 472 |
+
post_graphs.append(
|
| 473 |
+
_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
|
| 474 |
+
)
|
| 475 |
+
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
| 476 |
+
|
| 477 |
+
last_stage = stages[-1]
|
| 478 |
+
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
| 479 |
+
last_graph, last_output = _capture_cuda_graph(
|
| 480 |
+
lambda: self._depth_decode_last_post(
|
| 481 |
+
num_layers - 1,
|
| 482 |
+
last_stage.residual,
|
| 483 |
+
last_attn_context,
|
| 484 |
+
),
|
| 485 |
+
device,
|
| 486 |
+
)
|
| 487 |
+
post_graphs.append(
|
| 488 |
+
_DepthDecodeCudaGraphPostStage(
|
| 489 |
+
graph=last_graph, attn_context=last_attn_context
|
| 490 |
+
)
|
| 491 |
+
)
|
| 492 |
+
return _DepthDecodeCudaGraph(
|
| 493 |
+
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
| 494 |
+
pre_graph=pre_graph,
|
| 495 |
+
token_ids=token_ids,
|
| 496 |
+
cos=cos,
|
| 497 |
+
sin=sin,
|
| 498 |
+
positions=positions,
|
| 499 |
+
stages=tuple(stages),
|
| 500 |
+
post_graphs=tuple(post_graphs),
|
| 501 |
+
output=last_output,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
def _get_depth_decode_graph(
|
| 505 |
+
self,
|
| 506 |
+
next_input_ids: torch.Tensor,
|
| 507 |
+
*,
|
| 508 |
+
past_length: int,
|
| 509 |
+
attention_bias: torch.Tensor,
|
| 510 |
+
) -> _DepthDecodeCudaGraph:
|
| 511 |
+
key = self._depth_decode_key(next_input_ids, attention_bias)
|
| 512 |
+
decode_graph = self.graph
|
| 513 |
+
if decode_graph is None or decode_graph.cache_key != key:
|
| 514 |
+
decode_graph = self._build_depth_decode_graph(
|
| 515 |
+
next_input_ids,
|
| 516 |
+
past_length=past_length,
|
| 517 |
+
attention_bias=attention_bias,
|
| 518 |
+
)
|
| 519 |
+
self.graph = decode_graph
|
| 520 |
+
else:
|
| 521 |
+
decode_graph.token_ids.copy_(next_input_ids)
|
| 522 |
+
self._select_depth_decode_rope(
|
| 523 |
+
decode_graph.cos, decode_graph.sin, past_length=past_length
|
| 524 |
+
)
|
| 525 |
+
return decode_graph
|
| 526 |
+
|
| 527 |
+
def _run_depth_decode_attention_core(
|
| 528 |
+
self,
|
| 529 |
+
layer_idx: int,
|
| 530 |
+
stage: _DepthDecodeCudaGraphLayerStage,
|
| 531 |
+
*,
|
| 532 |
+
past_key_values: Cache,
|
| 533 |
+
attention_bias: torch.Tensor,
|
| 534 |
+
cache_position: torch.Tensor,
|
| 535 |
+
cos: torch.Tensor,
|
| 536 |
+
sin: torch.Tensor,
|
| 537 |
+
) -> torch.Tensor:
|
| 538 |
+
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
| 539 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 540 |
+
key_states, value_states = past_key_values.update(
|
| 541 |
+
stage.key,
|
| 542 |
+
stage.value,
|
| 543 |
+
layer_idx,
|
| 544 |
+
cache_kwargs,
|
| 545 |
+
)
|
| 546 |
+
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
| 547 |
+
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
| 548 |
+
attn_output = F.scaled_dot_product_attention(
|
| 549 |
+
stage.query,
|
| 550 |
+
key_states,
|
| 551 |
+
value_states,
|
| 552 |
+
attn_mask=attention_bias,
|
| 553 |
+
dropout_p=0.0,
|
| 554 |
+
is_causal=False,
|
| 555 |
+
)
|
| 556 |
+
return attn_output.transpose(1, 2)
|
| 557 |
+
|
| 558 |
+
def run(
|
| 559 |
+
self,
|
| 560 |
+
next_input_ids: torch.Tensor,
|
| 561 |
+
*,
|
| 562 |
+
past_key_values: Cache,
|
| 563 |
+
attention_bias: torch.Tensor,
|
| 564 |
+
past_length: int,
|
| 565 |
+
) -> Tuple[torch.Tensor, Cache]:
|
| 566 |
+
end = past_length + 1
|
| 567 |
+
decode_graph = self._get_depth_decode_graph(
|
| 568 |
+
next_input_ids,
|
| 569 |
+
past_length=past_length,
|
| 570 |
+
attention_bias=attention_bias,
|
| 571 |
+
)
|
| 572 |
+
cache_position = decode_graph.positions[past_length:end]
|
| 573 |
+
attention_bias_q = attention_bias[:, :, past_length:end, :end]
|
| 574 |
+
|
| 575 |
+
decode_graph.pre_graph.replay()
|
| 576 |
+
|
| 577 |
+
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
|
| 578 |
+
attn_context = self._run_depth_decode_attention_core(
|
| 579 |
+
layer_idx,
|
| 580 |
+
decode_graph.stages[layer_idx],
|
| 581 |
+
past_key_values=past_key_values,
|
| 582 |
+
attention_bias=attention_bias_q,
|
| 583 |
+
cache_position=cache_position,
|
| 584 |
+
cos=decode_graph.cos,
|
| 585 |
+
sin=decode_graph.sin,
|
| 586 |
+
)
|
| 587 |
+
post_graph.attn_context.copy_(attn_context)
|
| 588 |
+
post_graph.graph.replay()
|
| 589 |
+
|
| 590 |
+
return decode_graph.output, past_key_values
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def _cuda_graph_tensor_signature(
|
| 594 |
+
tensor: Optional[torch.Tensor],
|
| 595 |
+
) -> Optional[Tuple[Any, ...]]:
|
| 596 |
+
if tensor is None:
|
| 597 |
+
return None
|
| 598 |
+
return (
|
| 599 |
+
tuple(tensor.shape),
|
| 600 |
+
tuple(tensor.stride()),
|
| 601 |
+
str(tensor.dtype),
|
| 602 |
+
str(tensor.device),
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
|
| 607 |
+
sig = _cuda_graph_tensor_signature
|
| 608 |
+
return (
|
| 609 |
+
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
|
| 610 |
+
sig(context.cross_mask),
|
| 611 |
+
sig(context.self_mask),
|
| 612 |
+
sig(context.valid_action),
|
| 613 |
+
None
|
| 614 |
+
if context.rope_cache is None
|
| 615 |
+
else tuple(sig(t) for t in context.rope_cache),
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
|
| 620 |
+
sig = _cuda_graph_tensor_signature
|
| 621 |
+
return tuple(
|
| 622 |
+
(
|
| 623 |
+
sig(step.conditioning),
|
| 624 |
+
tuple(
|
| 625 |
+
tuple(sig(t) for t in block_modulation)
|
| 626 |
+
for block_modulation in step.block_modulations
|
| 627 |
+
),
|
| 628 |
+
tuple(sig(t) for t in step.final_modulation),
|
| 629 |
+
)
|
| 630 |
+
for step in modulations
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
|
| 635 |
+
sig = _cuda_graph_tensor_signature
|
| 636 |
+
return (
|
| 637 |
+
sig(inputs.trajectory),
|
| 638 |
+
_cuda_graph_context_signature(inputs.context),
|
| 639 |
+
_cuda_graph_modulation_signature(inputs.modulations),
|
| 640 |
+
sig(inputs.action_dim_is_pad),
|
| 641 |
+
int(steps),
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 646 |
+
if tensor is None:
|
| 647 |
+
return None
|
| 648 |
+
static = torch.empty_strided(
|
| 649 |
+
tuple(tensor.shape),
|
| 650 |
+
tuple(tensor.stride()),
|
| 651 |
+
device=tensor.device,
|
| 652 |
+
dtype=tensor.dtype,
|
| 653 |
+
)
|
| 654 |
+
static.copy_(tensor)
|
| 655 |
+
return static
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def _clone_static_context(context: Any) -> Any:
|
| 659 |
+
rope_cache = None
|
| 660 |
+
if context.rope_cache is not None:
|
| 661 |
+
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
|
| 662 |
+
return context.__class__(
|
| 663 |
+
kv_contexts=tuple(
|
| 664 |
+
(_clone_static_tensor(k), _clone_static_tensor(v))
|
| 665 |
+
for k, v in context.kv_contexts
|
| 666 |
+
),
|
| 667 |
+
cross_mask=_clone_static_tensor(context.cross_mask),
|
| 668 |
+
self_mask=_clone_static_tensor(context.self_mask),
|
| 669 |
+
valid_action=_clone_static_tensor(context.valid_action),
|
| 670 |
+
rope_cache=rope_cache,
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
|
| 675 |
+
return tuple(
|
| 676 |
+
step.__class__(
|
| 677 |
+
conditioning=_clone_static_tensor(step.conditioning),
|
| 678 |
+
block_modulations=tuple(
|
| 679 |
+
tuple(_clone_static_tensor(t) for t in block_modulation)
|
| 680 |
+
for block_modulation in step.block_modulations
|
| 681 |
+
),
|
| 682 |
+
final_modulation=tuple(
|
| 683 |
+
_clone_static_tensor(t) for t in step.final_modulation
|
| 684 |
+
),
|
| 685 |
+
)
|
| 686 |
+
for step in modulations
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
| 691 |
+
return _ActionFlowInputs(
|
| 692 |
+
trajectory=_clone_static_tensor(inputs.trajectory),
|
| 693 |
+
context=_clone_static_context(inputs.context),
|
| 694 |
+
modulations=_clone_static_modulations(inputs.modulations),
|
| 695 |
+
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def _copy_context_(dst: Any, src: Any) -> None:
|
| 700 |
+
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
| 701 |
+
dst_k.copy_(src_k)
|
| 702 |
+
dst_v.copy_(src_v)
|
| 703 |
+
if src.cross_mask is not None:
|
| 704 |
+
dst.cross_mask.copy_(src.cross_mask)
|
| 705 |
+
if src.self_mask is not None:
|
| 706 |
+
dst.self_mask.copy_(src.self_mask)
|
| 707 |
+
if src.valid_action is not None:
|
| 708 |
+
dst.valid_action.copy_(src.valid_action)
|
| 709 |
+
if src.rope_cache is not None:
|
| 710 |
+
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
| 711 |
+
dst_tensor.copy_(src_tensor)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
|
| 715 |
+
dst.trajectory.copy_(src.trajectory)
|
| 716 |
+
_copy_context_(dst.context, src.context)
|
| 717 |
+
if src.action_dim_is_pad is not None:
|
| 718 |
+
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 722 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 723 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 724 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def _apply_rotary_pos_emb(
|
| 728 |
+
q: torch.Tensor,
|
| 729 |
+
k: torch.Tensor,
|
| 730 |
+
cos: torch.Tensor,
|
| 731 |
+
sin: torch.Tensor,
|
| 732 |
+
unsqueeze_dim: int = 1,
|
| 733 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 734 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 735 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 736 |
+
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
| 737 |
+
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
| 738 |
+
return q_embed, k_embed
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 742 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 743 |
+
if n_rep == 1:
|
| 744 |
+
return hidden_states
|
| 745 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 746 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 747 |
+
)
|
| 748 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def _capture_cuda_graph(
|
| 752 |
+
fn,
|
| 753 |
+
device: torch.device,
|
| 754 |
+
*,
|
| 755 |
+
after_warmup=None,
|
| 756 |
+
) -> Tuple[torch.cuda.CUDAGraph, Any]:
|
| 757 |
+
warmup_stream = torch.cuda.Stream(device=device)
|
| 758 |
+
warmup_stream.wait_stream(torch.cuda.current_stream(device))
|
| 759 |
+
with torch.cuda.stream(warmup_stream):
|
| 760 |
+
fn()
|
| 761 |
+
torch.cuda.current_stream(device).wait_stream(warmup_stream)
|
| 762 |
+
if after_warmup is not None:
|
| 763 |
+
after_warmup()
|
| 764 |
+
|
| 765 |
+
graph = torch.cuda.CUDAGraph()
|
| 766 |
+
with torch.cuda.graph(graph):
|
| 767 |
+
output = fn()
|
| 768 |
+
return graph, output
|
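`_capture_cuda_graph` follows the standard warm-up/capture/replay recipe for `torch.cuda.CUDAGraph`. A self-contained sketch of that pattern, added by the editor with a toy workload standing in for the model step, looks like this:

```python
# Illustrative sketch of the capture/replay pattern (toy matmul instead of the model).
# Requires a CUDA device.
import torch

device = torch.device("cuda")
static_x = torch.randn(64, 64, device=device)
weight = torch.randn(64, 64, device=device)

def step():
    return static_x @ weight

# Warm up on a side stream so lazy allocations do not end up inside the graph.
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(s):
    step()
torch.cuda.current_stream(device).wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = step()

static_x.copy_(torch.randn(64, 64, device=device))  # overwrite inputs in place
graph.replay()                                       # static_out now holds the new result
```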
model-00001-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5aa56113a3c1c95f0c03b90bada5b7ac2babf59c752a3d6201b8362a530c6809
|
| 3 |
+
size 4919324120
|
model-00002-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1fd95aec9ad144d3f4eac42db88a000321e016beeae606ecfcb84fcbdb13ecf4
|
| 3 |
+
size 4844690992
|
model-00003-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6c455434915e74a8d750bedd0fc4c23b5a8ce37954aaff9c3e3f22085a26890
|
| 3 |
+
size 4844691024
|
model-00004-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec4de13f1384f7b206f6090ade912667df17bfdb5da66fab7a3edf5c18c48ded
|
| 3 |
+
size 4998106920
|
model-00005-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2808d1df315627f809d54baf4108160b10daa5095d64d4319ed5cd24e94be851
|
| 3 |
+
size 2334605176
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
modeling_molmoact2.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
norm_stats.json
ADDED
|
@@ -0,0 +1,238 @@
|
{
  "format": "molmoact2_norm_stats.v1",
  "norm_mode": "q01_q99",
  "metadata_by_tag": {
    "libero": {
      "action_key": "action",
      "state_key": "observation.state",
      "camera_keys": ["observation.images.image", "observation.images.wrist_image"],
      "normalize_gripper": false,
      "action_horizon": 10,
      "n_action_steps": 10,
      "setup_type": "single franka robotic arm in libero",
      "control_mode": "delta end-effector pose",
      "action_stats": {
        "min": [-0.9375, -0.9375, -0.9375, -0.2582142949104309, -0.375, -0.3675000071525574, -1.0],
        "max": [0.9375, 0.9375, 0.9375, 0.3557142913341522, 0.375, 0.375, 1.0],
        "mean": [0.06278156570450202, 0.08684081017968912, -0.09037305936952836, 0.0005407430783705141, 0.0056433796450358715, -0.005229098518603562, -0.04964072167678376],
        "std": [0.3355237114945633, 0.3784469867268323, 0.44472859911256607, 0.03924354049229973, 0.06339296407444922, 0.07797027713976648, 0.9987671529022402],
        "count": [273465.0],
        "q01": [-0.6792031928846481, -0.7736573115323259, -0.8728073904104404, -0.10277447185825356, -0.15509810617083444, -0.20289961475228455, -1.0],
        "q10": [-0.328718721971874, -0.3626162647358338, -0.6610056625361599, -0.03907064459203904, -0.06428551162168497, -0.07928202560631951, -1.0],
        "q50": [0.015333975787982875, 0.006437010746251905, -0.07265095199149316, -1.701317418858285e-05, 0.00021801956089207239, -5.852172701796134e-05, -0.12287333595187695],
        "q90": [0.5238177265233007, 0.671417970219526, 0.5384412174699407, 0.040331002487738146, 0.08240652401791884, 0.0690125677722944, 0.9999141552827842],
        "q99": [0.8536542808794264, 0.8637811051429717, 0.9363295547540081, 0.13045695485814487, 0.18015313802054606, 0.24129727661704234, 0.9999914155282784],
        "names": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        "mask": [true, true, true, true, true, true, false]
      },
      "state_stats": {
        "min": [-0.4828203022480011, -0.3255046010017395, 0.008128180168569088, 0.35277295112609863, -3.641430377960205, -1.842738389968872, -0.0013586411951109767, -0.042040832340717316],
        "max": [0.21031762659549713, 0.39128610491752625, 1.3660105466842651, 3.6714255809783936, 3.560650587081909, 1.386339545249939, 0.04233968257904053, 0.0013633022317662835],
        "mean": [-0.04651878279191748, 0.034409066787269356, 0.7645525031210381, 2.9722094975655056, -0.22046978549041713, -0.1255794031738752, 0.026914253269017054, -0.027190783616938205],
        "std": [0.10494395508556839, 0.1517661933220375, 0.378516707505034, 0.34427344187858827, 0.9069468516043042, 0.32539190149967406, 0.01417590382231912, 0.014058894296088888],
        "count": [273465.0],
        "q01": [-0.31479429659059555, -0.26691552643710226, 0.5194626050191016, 2.159994551314992, -1.801294177865994, -0.8949778881389838, 0.003382730811955442, -0.04008920533069468],
        "q10": [-0.18409729127502492, -0.158759498072202, 0.5694822295083012, 2.501970046458546, -1.1889107640062022, -0.5297043790093273, 0.007573322430226042, -0.039827946964434036],
        "q50": [-0.02822545357081922, 0.029718887641213443, 0.7185643731428462, 3.0915725099012166, -0.12491069931831773, -0.08338984738533357, 0.030648370056451133, -0.031519123023466586],
        "q90": [0.06725052913150302, 0.23387160335018267, 0.9599947530498419, 3.1743361507512997, 0.5456820212337484, 0.20414514594693875, 0.03985537019679712, -0.008040434619037518],
        "q99": [0.1222615490116252, 0.3140223876046953, 1.042961724319958, 3.277638017923068, 1.724488202195691, 0.5659922739094448, 0.04009682017699841, -0.003493522538066522],
        "names": ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"],
        "mask": [true, true, true, true, true, true, true, false]
      }
    }
  }
}
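Since `norm_mode` is `q01_q99`, downstream code typically rescales the policy's normalized action outputs back to physical units using the 1st/99th-percentile bounds stored above. The snippet below is a minimal, hypothetical sketch of that mapping; the `[-1, 1]` output convention and the `unnormalize_actions` helper are assumptions for illustration, not the repository's actual inference code.

```python
import json
import numpy as np

def unnormalize_actions(normed: np.ndarray, stats: dict) -> np.ndarray:
    """Map actions from [-1, 1] back to raw units using q01/q99 bounds (assumed convention)."""
    q01 = np.asarray(stats["q01"], dtype=np.float32)
    q99 = np.asarray(stats["q99"], dtype=np.float32)
    mask = np.asarray(stats["mask"], dtype=bool)   # False => dim left as-is (here: gripper)
    raw = 0.5 * (normed + 1.0) * (q99 - q01) + q01
    return np.where(mask, raw, normed)

with open("norm_stats.json") as f:
    libero = json.load(f)["metadata_by_tag"]["libero"]

# action_horizon = 10 chunks of 7-dim delta end-effector actions
actions = unnormalize_actions(np.zeros((10, 7), dtype=np.float32), libero["action_stats"])
```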
processing_molmoact2.py
ADDED
|
@@ -0,0 +1,418 @@
| 1 |
+
"""
|
| 2 |
+
Processor class for MolmoAct2.
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
import dataclasses
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from transformers.image_utils import ImageInput
|
| 10 |
+
from transformers.video_utils import VideoInput
|
| 11 |
+
from transformers.processing_utils import (
|
| 12 |
+
Unpack,
|
| 13 |
+
ProcessingKwargs,
|
| 14 |
+
ProcessorMixin,
|
| 15 |
+
)
|
| 16 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 17 |
+
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
from transformers import AutoTokenizer
|
| 21 |
+
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
| 22 |
+
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
| 29 |
+
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
| 30 |
+
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
| 31 |
+
IM_START_TOKEN = f"<im_start>"
|
| 32 |
+
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
| 33 |
+
FRAME_START_TOKEN = f"<frame_start>"
|
| 34 |
+
IM_END_TOKEN = f"<im_end>"
|
| 35 |
+
FRAME_END_TOKEN = f"<frame_end>"
|
| 36 |
+
IM_COL_TOKEN = f"<im_col>"
|
| 37 |
+
IMAGE_PROMPT = "<|image|>"
|
| 38 |
+
VIDEO_PROMPT = "<|video|>"
|
| 39 |
+
|
| 40 |
+
IMAGE_TOKENS = [
|
| 41 |
+
IMAGE_PATCH_TOKEN,
|
| 42 |
+
IM_COL_TOKEN,
|
| 43 |
+
IM_START_TOKEN,
|
| 44 |
+
LOW_RES_IMAGE_START_TOKEN,
|
| 45 |
+
FRAME_START_TOKEN,
|
| 46 |
+
IM_END_TOKEN,
|
| 47 |
+
FRAME_END_TOKEN,
|
| 48 |
+
IMAGE_LOW_RES_TOKEN,
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
|
| 53 |
+
"""MolmoAct2 processor kwargs"""
|
| 54 |
+
images_kwargs: MolmoAct2ImagesKwargs
|
| 55 |
+
videos_kwargs: MolmoAct2VideoProcessorKwargs
|
| 56 |
+
_defaults = {
|
| 57 |
+
"text_kwargs": {
|
| 58 |
+
"padding": False,
|
| 59 |
+
"return_mm_token_type_ids": True,
|
| 60 |
+
},
|
| 61 |
+
"videos_kwargs": {"return_metadata": True},
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MolmoAct2Processor(ProcessorMixin):
|
| 66 |
+
attributes = ["image_processor", "video_processor", "tokenizer"]
|
| 67 |
+
optional_attributes = [
|
| 68 |
+
"chat_template",
|
| 69 |
+
"time_mode",
|
| 70 |
+
"image_use_col_tokens",
|
| 71 |
+
"use_single_crop_col_tokens",
|
| 72 |
+
"use_single_crop_start_token",
|
| 73 |
+
"video_use_col_tokens",
|
| 74 |
+
"use_frame_special_tokens",
|
| 75 |
+
]
|
| 76 |
+
image_processor_class = "AutoImageProcessor"
|
| 77 |
+
video_processor_class = "AutoVideoProcessor"
|
| 78 |
+
tokenizer_class = "AutoTokenizer"
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
image_processor: MolmoAct2ImageProcessor = None,
|
| 83 |
+
video_processor: MolmoAct2VideoProcessor = None,
|
| 84 |
+
tokenizer: AutoTokenizer = None,
|
| 85 |
+
chat_template: Optional[str] = None,
|
| 86 |
+
image_use_col_tokens: Optional[bool] = True,
|
| 87 |
+
use_single_crop_col_tokens: Optional[bool] = None,
|
| 88 |
+
use_single_crop_start_token: Optional[bool] = True,
|
| 89 |
+
video_use_col_tokens: Optional[bool] = False,
|
| 90 |
+
use_frame_special_tokens: Optional[bool] = True,
|
| 91 |
+
**kwargs
|
| 92 |
+
) -> None:
|
| 93 |
+
super().__init__(
|
| 94 |
+
image_processor,
|
| 95 |
+
video_processor,
|
| 96 |
+
tokenizer,
|
| 97 |
+
chat_template=chat_template,
|
| 98 |
+
)
|
| 99 |
+
self.image_use_col_tokens = image_use_col_tokens
|
| 100 |
+
self.use_single_crop_col_tokens = use_single_crop_col_tokens
|
| 101 |
+
self.use_single_crop_start_token = use_single_crop_start_token
|
| 102 |
+
self.video_use_col_tokens = video_use_col_tokens
|
| 103 |
+
self.use_frame_special_tokens = use_frame_special_tokens
|
| 104 |
+
|
| 105 |
+
self.image_placeholder_token = IMAGE_PROMPT
|
| 106 |
+
self.video_placeholder_token = VIDEO_PROMPT
|
| 107 |
+
self.image_token_ids = [
|
| 108 |
+
tokenizer.convert_tokens_to_ids(token)
|
| 109 |
+
for token in IMAGE_TOKENS
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
def get_image_tokens(self, image_grid: np.ndarray):
|
| 113 |
+
resized_h, resized_w, height, width = image_grid
|
| 114 |
+
if int(height) == 0 or int(width) == 0:
|
| 115 |
+
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
| 116 |
+
use_single_crop_col_tokens = (
|
| 117 |
+
self.image_use_col_tokens
|
| 118 |
+
if self.use_single_crop_col_tokens is None
|
| 119 |
+
else self.use_single_crop_col_tokens
|
| 120 |
+
)
|
| 121 |
+
if use_single_crop_col_tokens:
|
| 122 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 123 |
+
joint = [
|
| 124 |
+
[IM_START_TOKEN],
|
| 125 |
+
np.tile(per_row, [resized_h]),
|
| 126 |
+
[IM_END_TOKEN],
|
| 127 |
+
]
|
| 128 |
+
return np.concatenate(joint)
|
| 129 |
+
per_row = np.full(width, IMAGE_PATCH_TOKEN)
|
| 130 |
+
if self.image_use_col_tokens:
|
| 131 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 132 |
+
joint = [
|
| 133 |
+
[IM_START_TOKEN],
|
| 134 |
+
np.tile(per_row, [height]),
|
| 135 |
+
[IM_END_TOKEN],
|
| 136 |
+
]
|
| 137 |
+
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
| 138 |
+
use_single_crop_col_tokens = (
|
| 139 |
+
self.image_use_col_tokens
|
| 140 |
+
if self.use_single_crop_col_tokens is None
|
| 141 |
+
else self.use_single_crop_col_tokens
|
| 142 |
+
)
|
| 143 |
+
image_start_token = (
|
| 144 |
+
LOW_RES_IMAGE_START_TOKEN
|
| 145 |
+
if self.use_single_crop_start_token
|
| 146 |
+
else IM_START_TOKEN
|
| 147 |
+
)
|
| 148 |
+
if use_single_crop_col_tokens:
|
| 149 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 150 |
+
joint = [
|
| 151 |
+
[image_start_token],
|
| 152 |
+
np.tile(per_row, [resized_h]),
|
| 153 |
+
[IM_END_TOKEN],
|
| 154 |
+
] + joint
|
| 155 |
+
|
| 156 |
+
return np.concatenate(joint)
|
| 157 |
+
|
| 158 |
+
def get_video_string(
|
| 159 |
+
self,
|
| 160 |
+
video_grid: np.ndarray,
|
| 161 |
+
timestamps: np.ndarray,
|
| 162 |
+
):
|
| 163 |
+
if self.use_frame_special_tokens:
|
| 164 |
+
start_token_id = FRAME_START_TOKEN
|
| 165 |
+
end_token_id = FRAME_END_TOKEN
|
| 166 |
+
else:
|
| 167 |
+
start_token_id = IM_START_TOKEN
|
| 168 |
+
end_token_id = IM_END_TOKEN
|
| 169 |
+
|
| 170 |
+
num_frames, h, w = video_grid
|
| 171 |
+
video_string: str = ""
|
| 172 |
+
for frame_idx, frame_time in enumerate(timestamps):
|
| 173 |
+
# `per-frame-compact` time mode
|
| 174 |
+
prev_space = " " if frame_idx > 0 else ""
|
| 175 |
+
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
|
| 176 |
+
|
| 177 |
+
video_string += frame_prefix
|
| 178 |
+
per_row = np.full(w, IMAGE_PATCH_TOKEN)
|
| 179 |
+
if self.video_use_col_tokens:
|
| 180 |
+
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
| 181 |
+
extra_tokens = np.tile(per_row, [h])
|
| 182 |
+
video_tokens = [
|
| 183 |
+
[start_token_id],
|
| 184 |
+
extra_tokens,
|
| 185 |
+
[end_token_id],
|
| 186 |
+
]
|
| 187 |
+
video_string += "".join(np.concatenate(video_tokens, 0))
|
| 188 |
+
|
| 189 |
+
return video_string
|
| 190 |
+
|
| 191 |
+
def insert_bos(
|
| 192 |
+
self,
|
| 193 |
+
input_ids: np.ndarray,
|
| 194 |
+
attention_mask: np.ndarray,
|
| 195 |
+
bos_token_id: int,
|
| 196 |
+
pad_token_id: int,
|
| 197 |
+
):
|
| 198 |
+
"""
|
| 199 |
+
Args:
|
| 200 |
+
input_ids: [B, S] array with left padding
|
| 201 |
+
attention_mask: [B, S] array (0 for pad, 1 for valid)
|
| 202 |
+
bos_token_id: int
|
| 203 |
+
pad_token_id: int
|
| 204 |
+
Returns:
|
| 205 |
+
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
|
| 206 |
+
attention_mask_out: same shape as input_ids_out
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
need_to_expand = len(input_ids.shape) == 1
|
| 210 |
+
if need_to_expand:
|
| 211 |
+
input_ids = input_ids[None, :]
|
| 212 |
+
attention_mask = attention_mask[None, :]
|
| 213 |
+
|
| 214 |
+
B, S = input_ids.shape
|
| 215 |
+
|
| 216 |
+
# Handle zero-length sequence
|
| 217 |
+
if S == 0:
|
| 218 |
+
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
|
| 219 |
+
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
|
| 220 |
+
if need_to_expand:
|
| 221 |
+
new_input_ids = new_input_ids[0]
|
| 222 |
+
new_attention_mask = new_attention_mask[0]
|
| 223 |
+
return new_input_ids, new_attention_mask
|
| 224 |
+
|
| 225 |
+
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
|
| 226 |
+
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
|
| 227 |
+
|
| 228 |
+
if bos_already_present:
|
| 229 |
+
if need_to_expand:
|
| 230 |
+
input_ids = input_ids[0]
|
| 231 |
+
attention_mask = attention_mask[0]
|
| 232 |
+
return input_ids, attention_mask
|
| 233 |
+
else:
|
| 234 |
+
new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
|
| 235 |
+
new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
|
| 236 |
+
|
| 237 |
+
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
|
| 238 |
+
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
|
| 239 |
+
tgt_idx = src_idx + 1  # shift right
|
| 240 |
+
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
|
| 241 |
+
|
| 242 |
+
# flatten valid_positions
|
| 243 |
+
flat_vals = input_ids[valid_mask]
|
| 244 |
+
flat_batch = batch_idx[valid_mask]
|
| 245 |
+
flat_tgt = tgt_idx[valid_mask]
|
| 246 |
+
|
| 247 |
+
new_input_ids[flat_batch, flat_tgt] = flat_vals
|
| 248 |
+
new_attention_mask[flat_batch, flat_tgt] = 1
|
| 249 |
+
|
| 250 |
+
insert_pos = first_valid_index
|
| 251 |
+
new_input_ids[np.arange(B), insert_pos] = bos_token_id
|
| 252 |
+
new_attention_mask[np.arange(B), insert_pos] = 1
|
| 253 |
+
|
| 254 |
+
if need_to_expand:
|
| 255 |
+
new_input_ids = new_input_ids[0]
|
| 256 |
+
new_attention_mask = new_attention_mask[0]
|
| 257 |
+
|
| 258 |
+
return new_input_ids, new_attention_mask
|
| 259 |
+
|
| 260 |
+
def __call__(
|
| 261 |
+
self,
|
| 262 |
+
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
|
| 263 |
+
images: ImageInput = None,
|
| 264 |
+
videos: VideoInput = None,
|
| 265 |
+
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
|
| 266 |
+
) -> BatchFeature:
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
text (`str`, `list[str]`, `list[list[str]]`):
|
| 271 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 272 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 273 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 274 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
| 275 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 276 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 277 |
+
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
|
| 278 |
+
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
|
| 279 |
+
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
|
| 280 |
+
- `"timestamps"`: `np.ndarray` of shape (T,)
|
| 281 |
+
- `"sampled_fps"`: `float` (optional)
|
| 282 |
+
- `"sampling_augmentation"`: `str` (optional)
|
| 283 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 284 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 285 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 286 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 287 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 288 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
`BatchFeature`: A [`BatchFeature`] with the following fields:
|
| 292 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 293 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 294 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
|
| 295 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 296 |
+
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
|
| 297 |
+
Returned when `images` is not `None`.
|
| 298 |
+
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
|
| 299 |
+
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
|
| 300 |
+
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
| 301 |
+
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
|
| 302 |
+
Returned when `videos` is not `None`.
|
| 303 |
+
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
output_kwargs = self._merge_kwargs(
|
| 307 |
+
MolmoAct2ProcessorKwargs,
|
| 308 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 309 |
+
**kwargs,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
if images is not None:
|
| 313 |
+
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 314 |
+
image_grids = image_inputs["image_grids"]
|
| 315 |
+
else:
|
| 316 |
+
image_inputs = {}
|
| 317 |
+
image_grids = None
|
| 318 |
+
|
| 319 |
+
if videos is not None:
|
| 320 |
+
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
| 321 |
+
video_grids = videos_inputs["video_grids"]
|
| 322 |
+
# If user has not requested video metadata, pop it
|
| 323 |
+
if "return_metadata" not in kwargs:
|
| 324 |
+
video_metadata = videos_inputs.pop("video_metadata")
|
| 325 |
+
else:
|
| 326 |
+
video_metadata = videos_inputs["video_metadata"]
|
| 327 |
+
else:
|
| 328 |
+
videos_inputs = {}
|
| 329 |
+
video_grids = None
|
| 330 |
+
|
| 331 |
+
if not isinstance(text, list):
|
| 332 |
+
text = [text]
|
| 333 |
+
|
| 334 |
+
text = text.copy() # below lines change text in-place
|
| 335 |
+
|
| 336 |
+
if image_grids is not None:
|
| 337 |
+
index = 0
|
| 338 |
+
for i in range(len(text)):
|
| 339 |
+
num_images = text[i].count(self.image_placeholder_token)
|
| 340 |
+
image_grids_i = image_grids[index:index+num_images]
|
| 341 |
+
for image_grid in image_grids_i:
|
| 342 |
+
image_tokens = self.get_image_tokens(image_grid)
|
| 343 |
+
image_string = "".join(image_tokens)
|
| 344 |
+
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
|
| 345 |
+
index += num_images
|
| 346 |
+
|
| 347 |
+
if video_grids is not None:
|
| 348 |
+
index = 0
|
| 349 |
+
for i in range(len(text)):
|
| 350 |
+
num_videos = text[i].count(self.video_placeholder_token)
|
| 351 |
+
assert num_videos in {0, 1}, "At most one video is supported for now"
|
| 352 |
+
video_grids_i = video_grids[index:index+num_videos]
|
| 353 |
+
metadata_i = video_metadata[index:index+num_videos]
|
| 354 |
+
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
| 355 |
+
video_string = self.get_video_string(
|
| 356 |
+
video_grid,
|
| 357 |
+
metadata.timestamps,
|
| 358 |
+
)
|
| 359 |
+
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
|
| 360 |
+
index += num_videos
|
| 361 |
+
|
| 362 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 363 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 364 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 365 |
+
|
| 366 |
+
input_ids = text_inputs["input_ids"]
|
| 367 |
+
attention_mask = text_inputs["attention_mask"]
|
| 368 |
+
|
| 369 |
+
input_ids = np.array(input_ids)
|
| 370 |
+
attention_mask = np.array(attention_mask)
|
| 371 |
+
|
| 372 |
+
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
| 373 |
+
input_ids, attention_mask = self.insert_bos(
|
| 374 |
+
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
if return_mm_token_type_ids:
|
| 378 |
+
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
|
| 379 |
+
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
|
| 380 |
+
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
| 381 |
+
|
| 382 |
+
text_inputs["input_ids"] = input_ids.tolist()
|
| 383 |
+
text_inputs["attention_mask"] = attention_mask.tolist()
|
| 384 |
+
|
| 385 |
+
return BatchFeature(
|
| 386 |
+
data={**text_inputs, **image_inputs, **videos_inputs},
|
| 387 |
+
tensor_type=return_tensors,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def post_process_image_text_to_text(
|
| 391 |
+
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
| 392 |
+
):
|
| 393 |
+
"""
|
| 394 |
+
Post-process the output of the model to decode the text.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
| 398 |
+
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
| 399 |
+
or `(sequence_length,)`.
|
| 400 |
+
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
| 401 |
+
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
| 402 |
+
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
| 403 |
+
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
| 404 |
+
**kwargs:
|
| 405 |
+
Additional arguments to be passed to the tokenizer's `batch_decode` method.
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
`list[str]`: The decoded text.
|
| 409 |
+
"""
|
| 410 |
+
return self.tokenizer.batch_decode(
|
| 411 |
+
generated_outputs,
|
| 412 |
+
skip_special_tokens=skip_special_tokens,
|
| 413 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 414 |
+
**kwargs,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
MolmoAct2Processor.register_for_auto_class()
|
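Because the processor registers itself for auto classes (and `processor_config.json` below maps `AutoProcessor` to this module), it can be loaded with `trust_remote_code=True`. The call below is a minimal usage sketch of the `__call__` signature above; the repo id and the prompt wording are placeholders, and the exact prompt format expected for action prediction is defined by the repository's inference code rather than here.

```python
from PIL import Image
from transformers import AutoProcessor

# Hypothetical repo id; the image path comes from the assets/ folder in this repo.
processor = AutoProcessor.from_pretrained("allenai/MolmoAct2-LIBERO", trust_remote_code=True)

inputs = processor(
    text="User: pick up the black bowl <|image|> Assistant:",  # <|image|> is the image placeholder token
    images=[Image.open("assets/sample_agentview_rgb.png")],
    return_tensors="pt",
)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
```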
processor_config.json
ADDED
|
@@ -0,0 +1,85 @@
{
  "auto_map": {
    "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
  },
  "image_processor": {
    "auto_map": {
      "AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
      "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
    },
    "crop_mode": "resize",
    "do_convert_rgb": true,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "MolmoAct2ImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "max_crops": 8,
    "overlap_margins": [4, 4],
    "patch_size": 14,
    "pooling_size": [2, 2],
    "resample": 2,
    "size": {"height": 378, "width": 378}
  },
  "image_use_col_tokens": true,
  "processor_class": "MolmoAct2Processor",
  "use_frame_special_tokens": true,
  "use_single_crop_col_tokens": false,
  "use_single_crop_start_token": true,
  "video_processor": {
    "auto_map": {
      "AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
      "AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
    },
    "data_format": "channels_first",
    "default_to_square": true,
    "do_convert_rgb": true,
    "do_normalize": true,
    "do_rescale": true,
    "do_resize": true,
    "do_sample_frames": true,
    "frame_sample_mode": "uniform_last_frame",
    "image_mean": [0.5, 0.5, 0.5],
    "image_std": [0.5, 0.5, 0.5],
    "max_fps": 2.0,
    "num_frames": 8,
    "patch_size": 14,
    "pooling_size": [3, 3],
    "resample": 2,
    "rescale_factor": 0.00392156862745098,
    "return_metadata": false,
    "sampling_fps": 2,
    "size": {"height": 378, "width": 378},
    "video_processor_type": "MolmoAct2VideoProcessor"
  },
  "video_use_col_tokens": false
}
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5395aefc9b1b7f0385d8c86a2f1775e5af81bdfbf9f2d97827ea37921d9f862
+size 11983605
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
{
  "add_prefix_space": false,
  "auto_map": {
    "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
  },
  "backend": "tokenizers",
  "bos_token": "<|im_end|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "errors": "replace",
  "extra_special_tokens": ["<im_start>", "<im_end>", "<im_patch>", "<im_col>", "<low_res_im_start>", "<|image|>", "<im_low>", "<frame_start>", "<frame_end>", "<|video|>", "<|points|>", "<|token_index|>", "<|vit_index|>", "<|vit_loc|>"],
  "is_local": false,
  "model_max_length": 1010000,
  "pad_token": "<|endoftext|>",
  "processor_class": "MolmoAct2Processor",
  "split_special_tokens": false,
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": null
}
video_processing_molmoact2.py
ADDED
|
@@ -0,0 +1,969 @@
| 1 |
+
"""Video processor class for MolmoAct2"""
|
| 2 |
+
from functools import partial
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
from contextlib import redirect_stdout
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from urllib.parse import urlparse
|
| 8 |
+
from typing import Optional, Union, Callable
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import einops
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms
|
| 15 |
+
|
| 16 |
+
from transformers.image_utils import (
|
| 17 |
+
IMAGENET_STANDARD_MEAN,
|
| 18 |
+
IMAGENET_STANDARD_STD,
|
| 19 |
+
ImageInput,
|
| 20 |
+
PILImageResampling,
|
| 21 |
+
SizeDict,
|
| 22 |
+
validate_kwargs,
|
| 23 |
+
)
|
| 24 |
+
from transformers.video_utils import (
|
| 25 |
+
VideoInput,
|
| 26 |
+
is_valid_video,
|
| 27 |
+
make_batched_videos,
|
| 28 |
+
make_batched_metadata,
|
| 29 |
+
VideoMetadata,
|
| 30 |
+
)
|
| 31 |
+
from transformers.processing_utils import Unpack, VideosKwargs
|
| 32 |
+
from transformers.video_processing_utils import BaseVideoProcessor
|
| 33 |
+
from transformers.utils import logging
|
| 34 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 35 |
+
from transformers.utils import (
|
| 36 |
+
is_av_available,
|
| 37 |
+
is_decord_available,
|
| 38 |
+
is_torchcodec_available,
|
| 39 |
+
is_yt_dlp_available,
|
| 40 |
+
TensorType,
|
| 41 |
+
logging,
|
| 42 |
+
to_numpy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
MAX_VIDEO_FPS = 8
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normalize_image(
|
| 52 |
+
image: np.ndarray,
|
| 53 |
+
image_mean: list[float],
|
| 54 |
+
image_std: list[float],
|
| 55 |
+
) -> np.ndarray:
|
| 56 |
+
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
| 57 |
+
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
| 58 |
+
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
| 59 |
+
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
| 60 |
+
return image
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def resize_image(
|
| 64 |
+
image: np.ndarray,
|
| 65 |
+
desired_output_size: list[int],
|
| 66 |
+
resample: PILImageResampling,
|
| 67 |
+
) -> np.ndarray:
|
| 68 |
+
if len(image.shape) == 3:
|
| 69 |
+
is_video = False
|
| 70 |
+
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
| 71 |
+
else:
|
| 72 |
+
is_video = True
|
| 73 |
+
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
|
| 74 |
+
dtype = image.dtype
|
| 75 |
+
if torch.is_floating_point(image):
|
| 76 |
+
in_min = 0.0
|
| 77 |
+
in_max = 1.0
|
| 78 |
+
resized = torchvision.transforms.Resize(
|
| 79 |
+
desired_output_size,
|
| 80 |
+
resample,
|
| 81 |
+
antialias=False,
|
| 82 |
+
)(image)
|
| 83 |
+
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
| 84 |
+
else:
|
| 85 |
+
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
|
| 86 |
+
in_min = 0.0
|
| 87 |
+
in_max = 255.0
|
| 88 |
+
resized = torchvision.transforms.Resize(
|
| 89 |
+
desired_output_size,
|
| 90 |
+
resample,
|
| 91 |
+
antialias=False,
|
| 92 |
+
)(image)
|
| 93 |
+
resized = torch.clip(resized, 0, 255).to(dtype)
|
| 94 |
+
|
| 95 |
+
resized = resized.to(torch.float32)
|
| 96 |
+
resized = (resized - in_min) / (in_max - in_min)
|
| 97 |
+
|
| 98 |
+
if is_video:
|
| 99 |
+
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
|
| 100 |
+
else:
|
| 101 |
+
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
| 102 |
+
|
| 103 |
+
return resized
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def build_resized_image(
|
| 107 |
+
image: np.ndarray,
|
| 108 |
+
base_image_input_size: list[int],
|
| 109 |
+
resample: PILImageResampling,
|
| 110 |
+
image_mean: list[float],
|
| 111 |
+
image_std: list[float],
|
| 112 |
+
image_patch_size: int,
|
| 113 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 114 |
+
resized = resize_image(
|
| 115 |
+
image, base_image_input_size, resample,
|
| 116 |
+
)
|
| 117 |
+
resized = normalize_image(resized, image_mean, image_std)
|
| 118 |
+
if len(resized.shape) == 3:
|
| 119 |
+
resized = np.expand_dims(resized, 0)
|
| 120 |
+
crop_patch_w = base_image_input_size[1] // image_patch_size
|
| 121 |
+
crop_patch_h = base_image_input_size[0] // image_patch_size
|
| 122 |
+
resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
| 123 |
+
return resized, resize_idx
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
| 127 |
+
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
| 128 |
+
if len(array.shape) == 3:
|
| 129 |
+
n_crops, h, w = array.shape
|
| 130 |
+
h_patches = h//patch_size
|
| 131 |
+
w_patches = w//patch_size
|
| 132 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
| 133 |
+
array = np.transpose(array, [0, 1, 3, 2, 4])
|
| 134 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size])
|
| 135 |
+
return array
|
| 136 |
+
else:
|
| 137 |
+
n_crops, h, w, c = array.shape
|
| 138 |
+
h_patches = h//patch_size
|
| 139 |
+
w_patches = w//patch_size
|
| 140 |
+
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
| 141 |
+
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
| 142 |
+
array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c])
|
| 143 |
+
return array
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def arange_for_pooling(
|
| 147 |
+
idx_arr: np.ndarray,
|
| 148 |
+
pool_h: int,
|
| 149 |
+
pool_w: int,
|
| 150 |
+
) -> np.ndarray:
|
| 151 |
+
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
| 152 |
+
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
| 153 |
+
idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
|
| 154 |
+
mode='constant',constant_values=-1)
|
| 155 |
+
return einops.rearrange(
|
| 156 |
+
idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def image_to_patches_and_grids(
|
| 160 |
+
image: ImageInput,
|
| 161 |
+
base_image_input_size: list[int],
|
| 162 |
+
resample: PILImageResampling,
|
| 163 |
+
image_mean: list[float],
|
| 164 |
+
image_std: list[float],
|
| 165 |
+
image_patch_size: int,
|
| 166 |
+
image_pooling_w: int,
|
| 167 |
+
image_pooling_h: int,
|
| 168 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 169 |
+
"""
|
| 170 |
+
:return image_grids, the shape of each image after pooling
|
| 171 |
+
:return crops, the image crops to processes with the ViT
|
| 172 |
+
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
| 173 |
+
patches in `crops` to pool for that token, masked with -1
|
| 174 |
+
"""
|
| 175 |
+
if isinstance(base_image_input_size, int):
|
| 176 |
+
base_image_input_size = (base_image_input_size, base_image_input_size)
|
| 177 |
+
|
| 178 |
+
pooling_w = image_pooling_w
|
| 179 |
+
pooling_h = image_pooling_h
|
| 180 |
+
|
| 181 |
+
resized, resize_idx = build_resized_image(
|
| 182 |
+
image,
|
| 183 |
+
base_image_input_size,
|
| 184 |
+
resample,
|
| 185 |
+
image_mean,
|
| 186 |
+
image_std,
|
| 187 |
+
image_patch_size,
|
| 188 |
+
)
|
| 189 |
+
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
| 190 |
+
h, w = pooling_idx.shape[:2]
|
| 191 |
+
pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
|
| 192 |
+
image_grid = [h, w]
|
| 193 |
+
return (
|
| 194 |
+
image_grid,
|
| 195 |
+
batch_pixels_to_patches(resized, image_patch_size),
|
| 196 |
+
pooling_idx,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_candidate_target_fps(
|
| 201 |
+
video_fps: Union[int, float],
|
| 202 |
+
sampling_fps: Union[int, float],
|
| 203 |
+
max_fps: Union[int, float] = MAX_VIDEO_FPS,
|
| 204 |
+
) -> list[float]:
|
| 205 |
+
"""
|
| 206 |
+
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
|
| 207 |
+
|
| 208 |
+
Examples:
|
| 209 |
+
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
|
| 210 |
+
[2, 6]
|
| 211 |
+
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
|
| 212 |
+
[1, 5]
|
| 213 |
+
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
|
| 214 |
+
[2]
|
| 215 |
+
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
|
| 216 |
+
Traceback (most recent call last):
|
| 217 |
+
...
|
| 218 |
+
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
|
| 219 |
+
"""
|
| 220 |
+
video_fps = int(video_fps)
|
| 221 |
+
sampling_fps = int(sampling_fps)
|
| 222 |
+
max_fps = int(max_fps)
|
| 223 |
+
|
| 224 |
+
if sampling_fps is None:
|
| 225 |
+
raise ValueError("sampling_fps must be provided")
|
| 226 |
+
if video_fps <= 0 or sampling_fps <= 0:
|
| 227 |
+
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
|
| 228 |
+
if video_fps % sampling_fps != 0:
|
| 229 |
+
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
|
| 230 |
+
|
| 231 |
+
candidates = []
|
| 232 |
+
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
|
| 233 |
+
if candidate > max_fps:
|
| 234 |
+
break
|
| 235 |
+
if video_fps % candidate == 0:
|
| 236 |
+
candidates.append(float(candidate))
|
| 237 |
+
|
| 238 |
+
return candidates
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def read_video_decord(
|
| 242 |
+
video_path,
|
| 243 |
+
sample_timestamps_fn: Callable,
|
| 244 |
+
**kwargs,
|
| 245 |
+
) -> np.ndarray:
|
| 246 |
+
"""
|
| 247 |
+
Decode a video using the Decord backend.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
video_path (`str`):
|
| 251 |
+
Path to the video file.
|
| 252 |
+
sample_timestamps_fn (`Callable`):
|
| 253 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 257 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 258 |
+
- `VideoMetadata` object.
|
| 259 |
+
"""
|
| 260 |
+
# Lazy import from decord
|
| 261 |
+
import importlib
|
| 262 |
+
decord = importlib.import_module("decord")
|
| 263 |
+
|
| 264 |
+
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
|
| 265 |
+
video_fps = vr.get_avg_fps()
|
| 266 |
+
total_num_frames = len(vr)
|
| 267 |
+
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
|
| 268 |
+
duration = time_stamps[-1][1] - time_stamps[0][0]
|
| 269 |
+
|
| 270 |
+
metadata = VideoMetadata(
|
| 271 |
+
total_num_frames=int(total_num_frames),
|
| 272 |
+
fps=float(video_fps),
|
| 273 |
+
duration=float(duration),
|
| 274 |
+
video_backend="decord",
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
| 278 |
+
target_timestamps = np.array(target_timestamps)
|
| 279 |
+
offset = time_stamps[0, 0]
|
| 280 |
+
|
| 281 |
+
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
|
| 282 |
+
ix = np.minimum(ix, len(time_stamps) - 1)
|
| 283 |
+
|
| 284 |
+
video = vr.get_batch(ix).asnumpy()
|
| 285 |
+
metadata.update(
|
| 286 |
+
{
|
| 287 |
+
"frames_indices": target_timestamps * video_fps,
|
| 288 |
+
"height": video.shape[1],
|
| 289 |
+
"width": video.shape[2],
|
| 290 |
+
}
|
| 291 |
+
)
|
| 292 |
+
return video, metadata
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def read_video_torchcodec(
|
| 296 |
+
video_path,
|
| 297 |
+
sample_timestamps_fn: Callable,
|
| 298 |
+
**kwargs,
|
| 299 |
+
) -> np.ndarray:
|
| 300 |
+
"""
|
| 301 |
+
Decode a video using torchcodec decoder.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
video_path (`str`):
|
| 305 |
+
Path to the video file.
|
| 306 |
+
sample_timestamps_fn (`Callable`):
|
| 307 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 311 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 312 |
+
- `VideoMetadata` object.
|
| 313 |
+
"""
|
| 314 |
+
# Lazy import torchcodec
|
| 315 |
+
import importlib
|
| 316 |
+
torchcodec = importlib.import_module("torchcodec")
|
| 317 |
+
|
| 318 |
+
decoder = torchcodec.decoders.VideoDecoder(
|
| 319 |
+
video_path,
|
| 320 |
+
# Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
|
| 321 |
+
seek_mode="exact",
|
| 322 |
+
# Allow FFmpeg decide on the number of threads for efficiency
|
| 323 |
+
num_ffmpeg_threads=0,
|
| 324 |
+
)
|
| 325 |
+
# If the first frame starts at > 0, we effectively clip the video starting at that time
|
| 326 |
+
# since (most) video players would also skip to that time
|
| 327 |
+
time_offset = decoder.metadata.begin_stream_seconds_from_content
|
| 328 |
+
# Note this duration does assume we started playing at `time_offset`
|
| 329 |
+
duration = decoder.metadata.duration_seconds
|
| 330 |
+
|
| 331 |
+
metadata = VideoMetadata(
|
| 332 |
+
total_num_frames=decoder.metadata.num_frames,
|
| 333 |
+
fps=decoder.metadata.average_fps,
|
| 334 |
+
duration=duration,
|
| 335 |
+
video_backend="torchcodec",
|
| 336 |
+
height=decoder.metadata.height,
|
| 337 |
+
width=decoder.metadata.width,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
| 341 |
+
|
| 342 |
+
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
|
| 343 |
+
# out-of-bounds, to handle this we sanity check then clip them
|
| 344 |
+
assert all(x >= 0 for x in target_timestamps)
|
| 345 |
+
assert all(x < duration+1e-6 for x in target_timestamps)
|
| 346 |
+
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
|
| 347 |
+
# exact boundary value, we should still get the first/last frame anyway
|
| 348 |
+
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
|
| 349 |
+
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
|
| 350 |
+
# Note we avoid using numpy ops here to reduce floating precision issues
|
| 351 |
+
timestamps = [x + time_offset for x in target_timestamps]
|
| 352 |
+
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
|
| 353 |
+
|
| 354 |
+
video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1) # Convert to THWC format
|
| 355 |
+
target_timestamps = np.array(target_timestamps)
|
| 356 |
+
metadata.frames_indices = target_timestamps * metadata.fps
|
| 357 |
+
|
| 358 |
+
return video, metadata
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def read_video_pyav(
|
| 362 |
+
video_path,
|
| 363 |
+
sample_timestamps_fn: Callable,
|
| 364 |
+
**kwargs,
|
| 365 |
+
) -> np.ndarray:
|
| 366 |
+
"""
|
| 367 |
+
Decode a video using the PyAV backend.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
video_path (`str`):
|
| 371 |
+
Path to the video file.
|
| 372 |
+
sample_timestamps_fn (`Callable`):
|
| 373 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 377 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 378 |
+
- `VideoMetadata` object.
|
| 379 |
+
"""
|
| 380 |
+
# Lazy import PyAV
|
| 381 |
+
import importlib
|
| 382 |
+
av = importlib.import_module("av")
|
| 383 |
+
|
| 384 |
+
with av.open(video_path) as container:
|
| 385 |
+
video_stream = container.streams.video[0]
|
| 386 |
+
fps = video_stream.average_rate or video_stream.guessed_rate
|
| 387 |
+
it = container.decode(video=0)
|
| 388 |
+
frames = list(it)
|
| 389 |
+
|
| 390 |
+
stream = container.streams.video[0]
|
| 391 |
+
start = frames[0].pts * stream.time_base
|
| 392 |
+
container_end = stream.duration
|
| 393 |
+
if container_end is not None:
|
| 394 |
+
container_end *= stream.time_base
|
| 395 |
+
if container_end is None or container_end < frames[-1].pts:
|
| 396 |
+
# Some problem with stream duration, so use the frame PTS directly
|
| 397 |
+
# and guess the duration of the last frame
|
| 398 |
+
end = frames[-1].pts * stream.time_base + 1/fps
|
| 399 |
+
else:
|
| 400 |
+
end = container_end
|
| 401 |
+
duration = float(end - start)
|
| 402 |
+
|
| 403 |
+
metadata = VideoMetadata(
|
| 404 |
+
total_num_frames=len(frames),
|
| 405 |
+
fps=float(fps),
|
| 406 |
+
duration=float(duration),
|
| 407 |
+
video_backend="pyav",
|
| 408 |
+
height=video_stream.height,
|
| 409 |
+
width=video_stream.width,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
| 413 |
+
offset = float(start)
|
| 414 |
+
|
| 415 |
+
target_timestamps = np.array(target_timestamps)
|
| 416 |
+
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
|
| 417 |
+
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
|
| 418 |
+
indices = np.minimum(indices, len(end_time_stamps) - 1)
|
| 419 |
+
|
| 420 |
+
video = np.stack(
|
| 421 |
+
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
|
| 422 |
+
axis=0,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
metadata.frames_indices = target_timestamps * fps
|
| 426 |
+
|
| 427 |
+
return video, metadata
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
VIDEO_DECODERS = {
|
| 431 |
+
"decord": read_video_decord,
|
| 432 |
+
"torchcodec": read_video_torchcodec,
|
| 433 |
+
"pyav": read_video_pyav,
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def load_video(
|
| 438 |
+
video: VideoInput,
|
| 439 |
+
backend: str = "decord",
|
| 440 |
+
sample_timestamps_fn: Optional[Callable] = None,
|
| 441 |
+
**kwargs,
|
| 442 |
+
):
|
| 443 |
+
"""
|
| 444 |
+
Loads `video` to a numpy array.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
video (`VideoInput`):
|
| 448 |
+
The video to convert to the numpy array format. Can be a link to video or local path.
|
| 449 |
+
backend (`str`, *optional*, defaults to `"decord"`):
|
| 450 |
+
The backend to use when loading the video. Can be any of ["decord", "pyav", "torchcodec"]. Defaults to "decord".
|
| 451 |
+
sample_timestamps_fn (`Callable`):
|
| 452 |
+
A callable function that will return timestamps at which the video should be sampled.
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
# Early exit if provided an array or `PIL` frames
|
| 456 |
+
if not isinstance(video, str):
|
| 457 |
+
metadata = [None] * len(video)
|
| 458 |
+
return video, metadata
|
| 459 |
+
|
| 460 |
+
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
| 461 |
+
if not is_yt_dlp_available():
|
| 462 |
+
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
| 463 |
+
# Lazy import from yt_dlp
|
| 464 |
+
import importlib
|
| 465 |
+
yt_dlp = importlib.import_module("yt_dlp")
|
| 466 |
+
|
| 467 |
+
buffer = BytesIO()
|
| 468 |
+
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
|
| 469 |
+
f.download([video])
|
| 470 |
+
bytes_obj = buffer.getvalue()
|
| 471 |
+
file_obj = BytesIO(bytes_obj)
|
| 472 |
+
elif video.startswith("http://") or video.startswith("https://"):
|
| 473 |
+
file_obj = BytesIO(requests.get(video).content)
|
| 474 |
+
elif os.path.isfile(video):
|
| 475 |
+
file_obj = video
|
| 476 |
+
else:
|
| 477 |
+
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
| 478 |
+
|
| 479 |
+
# can also load with decord, but not cv2/torchvision
|
| 480 |
+
# both will fail in case of url links
|
| 481 |
+
video_is_url = video.startswith("http://") or video.startswith("https://")
|
| 482 |
+
if video_is_url and backend == "opencv":
|
| 483 |
+
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
|
| 484 |
+
|
| 485 |
+
if (
|
| 486 |
+
(not is_decord_available() and backend == "decord")
|
| 487 |
+
or (not is_torchcodec_available() and backend == "torchcodec")
|
| 488 |
+
or (not is_av_available() and backend == "pyav")
|
| 489 |
+
):
|
| 490 |
+
raise ImportError(
|
| 491 |
+
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
| 492 |
+
f"Make sure to install {backend} before loading the video."
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
video_decoder = VIDEO_DECODERS[backend]
|
| 496 |
+
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
|
| 497 |
+
return video, metadata
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
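# Illustrative usage of `load_video` (not executed at import time; the file name and the
# sampling settings below are assumptions, not values shipped with this repo):
#
#     from functools import partial
#     processor = MolmoAct2VideoProcessor()
#     frames, meta = load_video(
#         "example_clip.mp4",  # local path; http(s) URLs are also accepted
#         backend="pyav",
#         sample_timestamps_fn=partial(
#             processor.sample_times, frame_sample_mode="uniform_last_frame", num_frames=8
#         ),
#     )
#     # `frames` is a (num_sampled_frames, H, W, 3) uint8 array, `meta` a `VideoMetadata` instance.
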
def get_target_fps(
    video_fps: float,
    max_frames: int,
    total_frames: int,
    frame_sample_mode: str,
    candidate_target_fps: tuple[float],
) -> float:
    """
    Get the target fps that best spans the video and has the most frames sampled
    """
    num_frames_sampled = 0
    selected_target_fps = None
    for target_fps in candidate_target_fps:
        step_size = max(int(video_fps / target_fps), 1)
        num_frames_sampled_at_fps = int(total_frames / step_size)
        if num_frames_sampled == 0:
            if "uniform" in frame_sample_mode:
                if num_frames_sampled_at_fps > max_frames:
                    break
            selected_target_fps = target_fps
            num_frames_sampled = num_frames_sampled_at_fps
        else:
            # the candidate sampling fps increases so frame count can't decrease
            assert num_frames_sampled <= num_frames_sampled_at_fps
            if num_frames_sampled_at_fps > max_frames:
                # choose the sampling fps that spans the video
                continue
            elif num_frames_sampled_at_fps > num_frames_sampled:
                # both are less than max_frames, choose the one with higher density of frames sampled
                selected_target_fps = target_fps
                num_frames_sampled = num_frames_sampled_at_fps
    return selected_target_fps

def get_frame_times_and_chosen_fps(
    selected_target_fps,
    total_frames,
    max_frames,
    video_fps
):
    if selected_target_fps is None:
        frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
    else:
        step_size = max(int(video_fps / selected_target_fps), 1)
        frame_indices = np.arange(0, total_frames, step_size)
        if len(frame_indices) > max_frames:
            frame_indices = frame_indices[:max_frames]
    return selected_target_fps, frame_indices

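# Worked example of the fps-selection helpers above (numbers are illustrative):
# for a 10 s clip at video_fps=30 (total_frames=300), max_frames=16 and
# candidate_target_fps=(1, 2, 4) with frame_sample_mode="fps":
#   - target_fps=1 -> step_size=30 -> 10 frames sampled (kept)
#   - target_fps=2 -> step_size=15 -> 20 frames > max_frames (skipped)
#   - target_fps=4 -> step_size=7  -> 42 frames > max_frames (skipped)
# so `get_target_fps` returns 1, and `get_frame_times_and_chosen_fps` then yields
# indices arange(0, 300, 30) = [0, 30, ..., 270].
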
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
    patch_size: Optional[int]
    pooling_size: Optional[list[int]]
    frame_sample_mode: Optional[str]
    max_fps: Optional[int]
    sampling_fps: Optional[int]

class MolmoAct2VideoProcessor(BaseVideoProcessor):
    resample = PILImageResampling.BILINEAR
    size = {"height": 378, "width": 378}
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    patch_size = 14
    pooling_size = [3, 3]
    do_sample_frames = True
    frame_sample_mode = "uniform_last_frame"
    max_fps = 2
    sampling_fps = 2
    valid_kwargs = MolmoAct2VideoProcessorKwargs
    model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]

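    # With the defaults above, each 378x378 frame gives 378 / 14 = 27 patches per side, and
    # 3x3 pooling reduces that to a 9x9 grid, i.e. 81 video tokens per frame (assuming a
    # single crop per frame; the exact layout is determined by `image_to_patches_and_grids`).
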
    def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
        super().__init__(**kwargs)
        if self.size is not None and (
            self.size.get("height", None) is None or self.size.get("width", None) is None
        ):
            raise ValueError("size must contain 'height' and 'width' keys.")

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if size is not None and ("height" not in size or "width" not in size):
            raise ValueError("size must contain 'height' and 'width' keys.")

        return super()._further_process_kwargs(size=size, **kwargs)

    def sample_times(
        self,
        metadata: VideoMetadata,
        frame_sample_mode: str,
        num_frames: int,
        max_fps: Optional[int] = None,
        sampling_fps: Optional[int] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Time-based sampling, used when the video still has to be decoded (e.g. a path or URL is passed)
        Args:
            metadata (`VideoMetadata`):
                Metadata of the video containing information about total duration, fps and total number of frames.
            frame_sample_mode (`str`, *optional*):
                Mode to sample frames. Defaults to `self.frame_sample_mode`.
            num_frames (`int`, *optional*):
                Maximum number of frames to sample. Defaults to `self.num_frames`.
            max_fps (`int`, *optional*):
                Maximum frames per second to sample.
            sampling_fps (`int`, *optional*):
                Sampling frames per second. Defaults to `self.sampling_fps`.
                Used when `frame_sample_mode` is `"fps"`.
        """
        frame_sample_mode = frame_sample_mode or self.frame_sample_mode
        num_frames = num_frames or self.num_frames
        sampling_fps = sampling_fps or self.sampling_fps

        duration = metadata.duration or metadata.total_num_frames / metadata.fps
        if frame_sample_mode == "fps":
            candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
            # Try larger and larger FPSs until we hit one that can't span the video
            target_fps = candidate_target_fps[0]
            for candidate_fps in candidate_target_fps[1:]:
                if num_frames / candidate_fps < duration:
                    break
                target_fps = candidate_fps
            times = np.arange(0, num_frames) / target_fps
            times = times[times < duration]
            return times
        elif frame_sample_mode == "uniform_last_frame":
            if max_fps is not None:
                max_duration = (num_frames - 1) / max_fps  # -1 to include the last frame
                if max_duration < duration:
                    times = np.linspace(
                        0, duration, num=num_frames, endpoint=True, dtype=np.float64
                    )
                else:
                    times = np.arange(0.0, stop=duration, step=1 / max_fps)
                    times = np.concatenate([times, [duration]], axis=0)
                    assert len(times) <= num_frames
            else:
                times = np.linspace(
                    0, duration, num=num_frames, endpoint=True, dtype=np.float64
                )
            return times
        else:
            raise NotImplementedError(frame_sample_mode)

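    # Worked example for `sample_times` with the class defaults (illustrative numbers):
    # with frame_sample_mode="uniform_last_frame", num_frames=16 and max_fps=2, a 20 s video
    # exceeds the (16 - 1) / 2 = 7.5 s budget, so 16 timestamps are spread uniformly over
    # [0, 20]; a 5 s video fits the budget, so timestamps are taken every 0.5 s plus the
    # final timestamp at 5.0 s (11 timestamps in total).
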
    def sample_frames(
        self,
        metadata: VideoMetadata,
        frame_sample_mode: Optional[str] = None,
        num_frames: Optional[int] = None,
        max_fps: Optional[int] = None,
        sampling_fps: Optional[int] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Frame-based sampling if an array video is passed
        Args:
            metadata (`VideoMetadata`):
                Metadata of the video containing information about total duration, fps and total number of frames.
            frame_sample_mode (`str`, *optional*):
                Mode to sample frames. Defaults to `self.frame_sample_mode`.
            num_frames (`int`, *optional*):
                Maximum number of frames to sample. Defaults to `self.num_frames`.
            max_fps (`int`, *optional*):
                Maximum frames per second to sample.
            sampling_fps (`int`, *optional*):
                Sampling frames per second. Defaults to `self.sampling_fps`.
                Used when `frame_sample_mode` is `"fps"`.
        """
        frame_sample_mode = frame_sample_mode or self.frame_sample_mode
        num_frames = num_frames or self.num_frames
        sampling_fps = sampling_fps or self.sampling_fps

        total_num_frames = metadata.total_num_frames
        if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
            duration = total_num_frames / metadata.fps
            if total_num_frames <= 2:
                return np.arange(total_num_frames).astype(int)
            if duration > (num_frames - 1) / max_fps:  # -1 to include the last frame
                # uniform fallback
                indices = np.linspace(
                    0,
                    total_num_frames - 1,
                    num=min(num_frames, total_num_frames),
                    endpoint=True,
                ).astype(int)
                return indices
            else:
                float_indices = np.arange(
                    0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
                )
                if np.round(float_indices[-1]) != total_num_frames - 1:
                    float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
                indices = np.round(float_indices).astype(int)
                assert indices[-1] < total_num_frames
                assert len(float_indices) <= num_frames
                return indices
        elif frame_sample_mode == "uniform_last_frame":
            indices = np.linspace(
                0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
            ).astype(int)
            return indices
        elif frame_sample_mode == "fps":
            candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
            selected_target_fps = get_target_fps(
                metadata.fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidate_target_fps,
            )
            _, indices = get_frame_times_and_chosen_fps(
                selected_target_fps,
                total_num_frames,
                num_frames,
                metadata.fps,
            )
            return indices
        else:
            raise NotImplementedError(frame_sample_mode)

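    # Worked example for `sample_frames` (illustrative numbers): a 3 s array video with
    # total_num_frames=90 at 30 fps, num_frames=16, max_fps=2 and
    # frame_sample_mode="uniform_last_frame" fits the (16 - 1) / 2 = 7.5 s budget, so indices
    # step by fps / max_fps = 15 and the last frame is appended: [0, 15, 30, 45, 60, 75, 89].
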
    def fetch_videos(
        self,
        video_url_or_urls: Union[str, list[str], list[list[str]]],
        sample_timestamps_fn=None
    ):
        """
        Convert a single or a list of urls into the corresponding `np.array` objects.

        If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
        returned.
        """
        if (
            (not is_decord_available())
            and (not is_torchcodec_available())
            and (not is_av_available())
        ):
            raise ImportError(
                "MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
            )

        if is_decord_available():
            backend = "decord"
        elif is_torchcodec_available():
            warnings.warn(
                "`decord` is not installed and cannot be used to decode the video by default. "
                "Falling back to `torchcodec`."
            )
            backend = "torchcodec"
        else:
            warnings.warn(
                "`decord` is not installed and cannot be used to decode the video by default. "
                "Falling back to `PyAV`."
            )
            backend = "pyav"

        if isinstance(video_url_or_urls, list):
            return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
        else:
            return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)

    def _decode_and_sample_videos(
        self,
        videos: VideoInput,
        video_metadata: Union[VideoMetadata, dict],
        do_sample_frames: Optional[bool] = None,
        sample_indices_fn: Optional[Callable] = None,
        sample_timestamps_fn: Optional[Callable] = None,
    ):
        """
        Decode input videos and sample frames if needed.
        """
        videos = make_batched_videos(videos)
        video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)

        # Frame-based sampling if an array video is passed
        # Otherwise, time-based sampling with decoding
        if is_valid_video(videos[0]) and do_sample_frames:
            assert video_metadata[0].fps is not None, "FPS must be provided for video input"
            sampled_videos = []
            sampled_metadata = []
            for video, metadata in zip(videos, video_metadata):
                indices = sample_indices_fn(metadata=metadata)
                metadata.frames_indices = indices
                sampled_videos.append(video[indices])
                sampled_metadata.append(metadata)
            videos = sampled_videos
            video_metadata = sampled_metadata
        elif not is_valid_video(videos[0]):
            if sample_indices_fn is None:
                logger.warning(
                    "do_sample_frames is False, but no video array was provided: "
                    "the video will be decoded and frames sampled using MolmoAct2's default sampling mode"
                )
            if isinstance(videos[0], list):
                raise ValueError(
                    "A list of images is not supported for video input!"
                )
            else:
                videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)

        return videos, video_metadata

    def _prepare_input_videos(
        self,
        videos: VideoInput,
        **kwargs,
    ) -> list[np.ndarray]:
        processed_videos = [to_numpy(video) for video in videos]
        return processed_videos

    def preprocess(
        self,
        videos: VideoInput,
        **kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
    ) -> BatchFeature:
        validate_kwargs(
            captured_kwargs=kwargs.keys(),
            valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
        )

        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        do_sample_frames = kwargs.pop("do_sample_frames")
        video_metadata = kwargs.pop("video_metadata")

        sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
        sample_timestamps_fn = partial(self.sample_times, **kwargs)
        videos, video_metadata = self._decode_and_sample_videos(
            videos,
            video_metadata=video_metadata,
            do_sample_frames=do_sample_frames,
            sample_indices_fn=sample_indices_fn,
            sample_timestamps_fn=sample_timestamps_fn,
        )
        videos = self._prepare_input_videos(videos=videos)

        kwargs = self._further_process_kwargs(**kwargs)

        return_metadata = kwargs.pop("return_metadata")
        preprocessed_videos = self._preprocess(videos=videos, **kwargs)
        if return_metadata:
            preprocessed_videos["video_metadata"] = video_metadata
        return preprocessed_videos

    def _preprocess(
        self,
        videos: list[np.ndarray],
        size: Optional[SizeDict] = None,
        resample: Optional[PILImageResampling] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        patch_size: Optional[int] = None,
        pooling_size: Optional[list[int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a video for the model.
        Args:
            videos (`VideoInput`):
                Video to preprocess.
            size (`SizeDict`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
                The pooling size of the vision adapter.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.

        Returns:
            A `BatchFeature` containing the following keys:
            - `pixel_values_videos`: The preprocessed videos.
            - `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
            - `video_grids`: The video grids.
        """
        if size.height is None or size.width is None:
            raise ValueError("size must contain 'height' and 'width' keys.")

        base_image_input_size = [size.height, size.width]

        resample = resample or self.resample
        image_mean = image_mean or self.image_mean
        image_std = image_std or self.image_std
        do_convert_rgb = do_convert_rgb or self.do_convert_rgb

        patch_size = patch_size or self.patch_size
        pooling_size = pooling_size or self.pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        batch_grids = []
        batch_crops = []
        batch_pooled_patches_idx = []

        for video in videos:
            all_crops = []
            pooled_patches_idx = []

            for frame in video:
                image_grid, crops, pooled_idx = image_to_patches_and_grids(
                    frame,
                    base_image_input_size,
                    resample,
                    image_mean,
                    image_std,
                    patch_size,
                    image_pooling_w,
                    image_pooling_h,
                )
                offset = sum(np.prod(x.shape[:2]) for x in all_crops)
                pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
                pooled_patches_idx.append(pooled_idx_with_offset)
                all_crops.append(crops)

            video_grid = np.array([len(video), image_grid[0], image_grid[1]])
            all_crops = np.concatenate(all_crops, 0)
            pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)

            batch_grids.append(video_grid)
            batch_crops.append(all_crops)
            batch_pooled_patches_idx.append(pooled_patches_idx)

        video_grids = np.stack(batch_grids, 0)
        pixel_values_videos = np.concatenate(batch_crops, 0)
        video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)

        data = dict(
            pixel_values_videos=pixel_values_videos,
            video_token_pooling=video_token_pooling,
            video_grids=video_grids,
        )

        return BatchFeature(data, tensor_type=return_tensors)


MolmoAct2VideoProcessor.register_for_auto_class()
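# Usage sketch for the registered processor (illustrative; the repo id and video path are
# assumptions, and `trust_remote_code=True` is needed because this class ships with the repo):
#
#     from transformers import AutoVideoProcessor
#     video_processor = AutoVideoProcessor.from_pretrained(
#         "allenai/MolmoAct2-LIBERO", trust_remote_code=True
#     )
#     outputs = video_processor("example_clip.mp4", return_tensors="np")
#     # outputs contains "pixel_values_videos", "video_token_pooling" and "video_grids".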