Add files using upload-large-folder tool
Browse files- .gitattributes +1 -0
- README.md +31 -3
- assets/MolmoAct2-Think.svg +36 -0
- chat_template.jinja +1 -0
- config.json +160 -0
- configuration_molmoact2.py +565 -0
- generation_config.json +6 -0
- image_processing_molmoact2.py +546 -0
- inference.py +768 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_molmoact2.py +0 -0
- norm_stats.json +1739 -0
- processing_molmoact2.py +418 -0
- processor_config.json +85 -0
- tokenizer.json +3 -0
- tokenizer_config.json +34 -0
- video_processing_molmoact2.py +969 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,31 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- molmoact2
|
| 5 |
+
- robotics
|
| 6 |
+
- image-text-to-text
|
| 7 |
+
- depth-reasoning
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
<img src="assets/MolmoAct2-Think.svg" alt="MolmoAct Think Logo" style="width: auto; height: 50px;">
|
| 11 |
+
|
| 12 |
+
# **MolmoAct2-Think**
|
| 13 |
+
|
| 14 |
+
MolmoAct2-Think extends MolmoAct2 with depth-token reasoning. Before producing an action, the model can predict a compact 10 x 10 discrete depth representation and condition the action expert on the resulting depth-aware VLM cache.
|
| 15 |
+
|
| 16 |
+
This checkpoint is the post-trained, multi-embodiment depth-reasoning model. It is intended as a foundation checkpoint for further robot fine-tuning rather than as a ready-to-run policy for a single deployment setting.
|
| 17 |
+
|
| 18 |
+
## Quick Links
|
| 19 |
+
|
| 20 |
+
- 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
|
| 21 |
+
- 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
|
| 22 |
+
- 📄 Paper: coming soon
|
| 23 |
+
- 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
|
| 24 |
+
- 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)
|
| 25 |
+
|
| 26 |
+
## Intended Use
|
| 27 |
+
|
| 28 |
+
Use this checkpoint for further fine-tuning when the downstream policy should use depth reasoning. It contains the VLM, action expert, and depth-token weights, plus normalization metadata for the post-training mixture in `norm_stats.json`.
|
| 29 |
+
|
| 30 |
+
This model card intentionally does not include direct policy inference code. For ready-to-run depth-reasoning inference, use the fine-tuned `MolmoAct2-Think-LIBERO` checkpoint.
|
| 31 |
+
|
assets/MolmoAct2-Think.svg
ADDED
|
|
chat_template.jinja
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% 
continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}{% if not (has_subtitle and loop.index == 2) and not (not has_subtitle and loop.first) %}{{ '<|im_end|>\n' }}{% endif %}{{ '<|im_start|>user\n' }}{{ text_content }}{{ '<|im_end|>\n' }}{% else %} {# assistant #}{{ '<|im_start|>assistant\n' }}{{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
config.json
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"action_end_token_id": 151933,
|
| 3 |
+
"action_expert_condition_source": "kv_cache",
|
| 4 |
+
"action_expert_config": {
|
| 5 |
+
"attn_dropout": 0.0,
|
| 6 |
+
"causal_attn": false,
|
| 7 |
+
"compile": "blocks",
|
| 8 |
+
"context_layer_norm": true,
|
| 9 |
+
"dropout": 0.0,
|
| 10 |
+
"ffn_multiple_of": 256,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"implementation": "new",
|
| 13 |
+
"max_action_dim": 32,
|
| 14 |
+
"max_horizon": 32,
|
| 15 |
+
"mlp_ratio": 4.0,
|
| 16 |
+
"model_type": "molmoact2_action_expert",
|
| 17 |
+
"num_heads": 8,
|
| 18 |
+
"num_layers": 36,
|
| 19 |
+
"qk_norm": true,
|
| 20 |
+
"qk_norm_eps": 1e-06,
|
| 21 |
+
"rope": true,
|
| 22 |
+
"rope_on_cross_attention": true,
|
| 23 |
+
"timestep_embed_dim": 256
|
| 24 |
+
},
|
| 25 |
+
"action_expert_depth_gate": false,
|
| 26 |
+
"action_expert_depth_gate_init_bias": -4.0,
|
| 27 |
+
"action_expert_depth_gate_per_layer": false,
|
| 28 |
+
"action_expert_layer_mode": "per_layer",
|
| 29 |
+
"action_format": "both",
|
| 30 |
+
"action_horizon": 30,
|
| 31 |
+
"action_output_token_id": 151931,
|
| 32 |
+
"action_start_token_id": 151932,
|
| 33 |
+
"action_token_start_id": 151934,
|
| 34 |
+
"adapter_config": {
|
| 35 |
+
"attention_dropout": 0.0,
|
| 36 |
+
"attn_implementation": "sdpa",
|
| 37 |
+
"float32_attention": true,
|
| 38 |
+
"head_dim": 72,
|
| 39 |
+
"hidden_act": "silu",
|
| 40 |
+
"hidden_size": 1152,
|
| 41 |
+
"image_feature_dropout": 0.0,
|
| 42 |
+
"initializer_range": 0.02,
|
| 43 |
+
"intermediate_size": 9728,
|
| 44 |
+
"model_type": "molmoact2",
|
| 45 |
+
"num_attention_heads": 16,
|
| 46 |
+
"num_key_value_heads": 16,
|
| 47 |
+
"pooling_attention_mask": true,
|
| 48 |
+
"residual_dropout": 0.0,
|
| 49 |
+
"text_hidden_size": 2560,
|
| 50 |
+
"vit_layers": [
|
| 51 |
+
-3,
|
| 52 |
+
-9
|
| 53 |
+
]
|
| 54 |
+
},
|
| 55 |
+
"add_action_expert": true,
|
| 56 |
+
"add_control_tokens": true,
|
| 57 |
+
"add_setup_tokens": true,
|
| 58 |
+
"architectures": [
|
| 59 |
+
"MolmoAct2ForConditionalGeneration"
|
| 60 |
+
],
|
| 61 |
+
"auto_map": {
|
| 62 |
+
"AutoConfig": "configuration_molmoact2.MolmoAct2Config",
|
| 63 |
+
"AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
|
| 64 |
+
},
|
| 65 |
+
"depth_end_token_id": 153984,
|
| 66 |
+
"depth_mode": 2,
|
| 67 |
+
"depth_output_token_id": 153982,
|
| 68 |
+
"depth_start_token_id": 153983,
|
| 69 |
+
"depth_token_start_id": 153985,
|
| 70 |
+
"dtype": "float32",
|
| 71 |
+
"enable_depth_reasoning": true,
|
| 72 |
+
"flow_matching_beta_alpha": 1.0,
|
| 73 |
+
"flow_matching_beta_beta": 1.5,
|
| 74 |
+
"flow_matching_cutoff": 1.0,
|
| 75 |
+
"flow_matching_num_steps": 10,
|
| 76 |
+
"flow_matching_time_offset": 0.001,
|
| 77 |
+
"flow_matching_time_scale": 0.999,
|
| 78 |
+
"frame_end_token_id": 155656,
|
| 79 |
+
"frame_start_token_id": 155655,
|
| 80 |
+
"image_col_id": 155651,
|
| 81 |
+
"image_end_token_id": 155649,
|
| 82 |
+
"image_high_res_id": 155650,
|
| 83 |
+
"image_low_res_id": 155654,
|
| 84 |
+
"image_patch_id": 155650,
|
| 85 |
+
"image_start_token_id": 155648,
|
| 86 |
+
"initializer_range": 0.02,
|
| 87 |
+
"low_res_image_start_token_id": 155652,
|
| 88 |
+
"mask_action_dim_padding": true,
|
| 89 |
+
"max_action_dim": 32,
|
| 90 |
+
"model_type": "molmoact2",
|
| 91 |
+
"n_obs_steps": 1,
|
| 92 |
+
"norm_stats_filename": "norm_stats.json",
|
| 93 |
+
"num_action_tokens": 2048,
|
| 94 |
+
"num_depth_codes": 100,
|
| 95 |
+
"num_depth_tokens": 128,
|
| 96 |
+
"num_state_tokens": 256,
|
| 97 |
+
"state_end_token_id": 151674,
|
| 98 |
+
"state_format": "discrete",
|
| 99 |
+
"state_start_token_id": 151673,
|
| 100 |
+
"state_token_start_id": 151675,
|
| 101 |
+
"text_config": {
|
| 102 |
+
"additional_vocab_size": 128,
|
| 103 |
+
"attention_dropout": 0.0,
|
| 104 |
+
"attn_implementation": "sdpa",
|
| 105 |
+
"embedding_dropout": 0.0,
|
| 106 |
+
"head_dim": 128,
|
| 107 |
+
"hidden_act": "silu",
|
| 108 |
+
"hidden_size": 2560,
|
| 109 |
+
"initializer_range": 0.02,
|
| 110 |
+
"intermediate_size": 9728,
|
| 111 |
+
"layer_norm_eps": 1e-06,
|
| 112 |
+
"max_position_embeddings": 16384,
|
| 113 |
+
"model_type": "molmoact2_text",
|
| 114 |
+
"norm_after": false,
|
| 115 |
+
"num_attention_heads": 32,
|
| 116 |
+
"num_hidden_layers": 36,
|
| 117 |
+
"num_key_value_heads": 8,
|
| 118 |
+
"qk_norm_type": "qwen3",
|
| 119 |
+
"qkv_bias": false,
|
| 120 |
+
"residual_dropout": 0.0,
|
| 121 |
+
"rope_parameters": {
|
| 122 |
+
"rope_theta": 5000000.0,
|
| 123 |
+
"rope_type": "default"
|
| 124 |
+
},
|
| 125 |
+
"rope_scaling_layers": null,
|
| 126 |
+
"rope_theta": 5000000.0,
|
| 127 |
+
"tie_word_embeddings": false,
|
| 128 |
+
"use_cache": true,
|
| 129 |
+
"use_qk_norm": true,
|
| 130 |
+
"vocab_size": 155648
|
| 131 |
+
},
|
| 132 |
+
"tie_word_embeddings": false,
|
| 133 |
+
"transformers_version": "5.3.0",
|
| 134 |
+
"use_frame_special_tokens": true,
|
| 135 |
+
"vit_config": {
|
| 136 |
+
"attention_dropout": 0.0,
|
| 137 |
+
"attn_implementation": "sdpa",
|
| 138 |
+
"float32_attention": true,
|
| 139 |
+
"head_dim": 72,
|
| 140 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 141 |
+
"hidden_size": 1152,
|
| 142 |
+
"image_default_input_size": [
|
| 143 |
+
378,
|
| 144 |
+
378
|
| 145 |
+
],
|
| 146 |
+
"image_num_pos": 729,
|
| 147 |
+
"image_patch_size": 14,
|
| 148 |
+
"initializer_range": 0.02,
|
| 149 |
+
"intermediate_size": 4304,
|
| 150 |
+
"layer_norm_eps": 1e-06,
|
| 151 |
+
"model_type": "molmoact2",
|
| 152 |
+
"num_attention_heads": 16,
|
| 153 |
+
"num_hidden_layers": 27,
|
| 154 |
+
"num_key_value_heads": 16,
|
| 155 |
+
"residual_dropout": 0.0
|
| 156 |
+
},
|
| 157 |
+
"bos_token_id": 151645,
|
| 158 |
+
"eos_token_id": 151645,
|
| 159 |
+
"pad_token_id": 151643
|
| 160 |
+
}
|
configuration_molmoact2.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MolmoAct2 configuration
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Any
|
| 6 |
+
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 9 |
+
from transformers.utils import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MolmoAct2VitConfig(PretrainedConfig):
    r"""
    Configuration class for a [`MolmoAct2VisionTransformer`].

    Instantiating a configuration with the defaults yields the architecture used by the
    MolmoAct2 vision tower. Configuration objects inherit from [`PretrainedConfig`] and can
    be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Example:
    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer

    >>> # Initializing a MolmoAct2VitConfig
    >>> configuration = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2VisionTransformer (with random weights)
    >>> model = MolmoAct2VisionTransformer(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    base_config_key = "vit_config"

    def __init__(
        self,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        num_hidden_layers: int = 27,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        hidden_act: str = "gelu_pytorch_tanh",
        layer_norm_eps: float = 1e-6,
        image_default_input_size: tuple[int, int] = (378, 378),
        image_patch_size: int = 14,
        image_num_pos: int = 577,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        initializer_range: float = 0.02,
        float32_attention: bool = True,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        # Keep a copy of the attention backend choice before handing it to the base class.
        self.attn_implementation = attn_implementation
        super().__init__(
            attn_implementation=attn_implementation,
            **kwargs
        )
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        # Image tokenization: default resolution, patch size and positional-embedding count.
        self.image_default_input_size = image_default_input_size
        self.image_patch_size = image_patch_size
        self.image_num_pos = image_num_pos
        # Regularization and numerics.
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.initializer_range = initializer_range
        self.float32_attention = float32_attention

    @property
    def image_num_patch(self):
        """Patch-grid shape (rows, cols) implied by the default input size and patch size."""
        height, width = self.image_default_input_size
        patch = self.image_patch_size
        return height // patch, width // patch
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class MolmoAct2AdapterConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
    It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
    defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```python
    >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone

    >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
    >>> vit_config = MolmoAct2VitConfig()
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
    >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)

    >>> # Accessing the model configuration
    >>> vit_configuration = model.vit_config
    >>> adapter_configuration = model.adapter_config
    ```"""

    model_type = "molmoact2"
    base_config_key = "adapter_config"

    def __init__(
        self,
        vit_layers: tuple = (-3, -9),
        pooling_attention_mask: bool = False,
        hidden_size: int = 1152,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 72,
        float32_attention: bool = True,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        hidden_act: str = "silu",
        intermediate_size: int = 18944,
        text_hidden_size: int = 3584,
        image_feature_dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        # Keep a copy of the attention backend choice before handing it to the base class.
        self.attn_implementation = attn_implementation
        super().__init__(
            attn_implementation=attn_implementation,
            **kwargs
        )
        # Which ViT hidden layers (negative indices from the top) feed the adapter.
        self.vit_layers = vit_layers
        self.pooling_attention_mask = pooling_attention_mask
        # Attention geometry of the pooling/adapter module.
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.float32_attention = float32_attention
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        # MLP projecting pooled image features into the text model's hidden space.
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.text_hidden_size = text_hidden_size
        self.image_feature_dropout = image_feature_dropout
        self.initializer_range = initializer_range
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class MolmoAct2TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
    `MolmoAct2TextModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:
    ```python
    >>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel

    >>> # Initializing a MolmoAct2TextConfig
    >>> configuration = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2TextModel (with random weights)
    >>> model = MolmoAct2TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2_text"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "blocks.*.self_attn.att_proj": "colwise",
        "blocks.*.self_attn.attn_out": "rowwise",
        "blocks.*.mlp.ff_proj": "colwise",
        "blocks.*.mlp.ff_out": "rowwise",
    }
    base_model_pp_plan = {
        "wte": (["input_ids"], ["inputs_embeds"]),
        "blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "ln_f": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        hidden_size: int = 3584,
        num_attention_heads: int = 28,
        num_key_value_heads: Optional[int] = 4,
        head_dim: int = 128,
        vocab_size: int = 152064,
        additional_vocab_size: int = 128,
        qkv_bias: bool = True,
        num_hidden_layers: int = 48,
        intermediate_size: int = 18944,
        hidden_act: str = "silu",
        embedding_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        max_position_embeddings: int = 4096,
        rope_theta: float = 1000000.0,
        rope_scaling: Optional[dict[str, Any]] = None,
        rope_scaling_layers: Optional[list[int]] = None,
        use_qk_norm: bool = False,
        qk_norm_type: str = "olmo",
        # NOTE: was annotated `int` in the original; 1e-6 is a float epsilon.
        layer_norm_eps: float = 1e-6,
        norm_after: bool = False,
        initializer_range: float = 0.02,
        use_cache=True,
        tie_word_embeddings=False,
        attn_implementation: str = "eager",
        **kwargs,
    ):
        # Keep a copy of the attention backend choice before handing it to the base class.
        self.attn_implementation = attn_implementation
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            attn_implementation=attn_implementation,
            **kwargs
        )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # Default GQA grouping to full multi-head attention when unspecified.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        # Base vocabulary plus extra slots for special/control tokens added on top.
        self.vocab_size = vocab_size
        self.additional_vocab_size = additional_vocab_size
        self.qkv_bias = qkv_bias
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.embedding_dropout = embedding_dropout
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        # Rotary position embedding configuration.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.rope_scaling_layers = rope_scaling_layers
        # QK normalization variant ("olmo" or e.g. "qwen3") and layer-norm settings.
        self.use_qk_norm = use_qk_norm
        self.qk_norm_type = qk_norm_type
        self.layer_norm_eps = layer_norm_eps
        self.norm_after = norm_after
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Validate the correctness of rotary position embeddings parameters
        rope_config_validation(self)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class MolmoAct2ActionExpertConfig(PretrainedConfig):
    r"""Configuration for the MolmoAct2 modern action expert."""

    model_type = "molmoact2_action_expert"
    base_config_key = "action_expert_config"

    def __init__(
        self,
        max_horizon: int = 32,
        max_action_dim: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 32,
        num_heads: int = 16,
        mlp_ratio: float = 8.0 / 3.0,
        ffn_multiple_of: int = 256,
        timestep_embed_dim: int = 256,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        context_layer_norm: bool = True,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        rope: bool = True,
        rope_on_cross_attention: bool = False,
        causal_attn: bool = False,
        compile: str = "blocks",
        implementation: str = "new",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Only the "new" action-expert implementation is supported by this export.
        if implementation != "new":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert.implementation='new'."
            )
        # Action-chunk shape limits: horizon (steps) and per-step action dimensionality.
        self.max_horizon = max_horizon
        self.max_action_dim = max_action_dim
        # Transformer geometry of the expert.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.ffn_multiple_of = ffn_multiple_of
        # Embedding width for the flow-matching timestep conditioning.
        self.timestep_embed_dim = timestep_embed_dim
        # Regularization.
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        # Normalization and attention options.
        self.context_layer_norm = context_layer_norm
        self.qk_norm = qk_norm
        self.qk_norm_eps = qk_norm_eps
        self.rope = rope
        self.rope_on_cross_attention = rope_on_cross_attention
        self.causal_attn = causal_attn
        # `compile` names which submodules get torch.compile'd (e.g. "blocks").
        self.compile = compile
        self.implementation = implementation
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class MolmoAct2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
    It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.

    Example:

    ```python
    >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig

    >>> # Initializing a MolmoAct2VitConfig
    >>> vit_config = MolmoAct2VitConfig()

    >>> # Initializing a MolmoAct2AdapterConfig
    >>> adapter_config = MolmoAct2AdapterConfig()

    >>> # Initializing a MolmoAct2TextConfig
    >>> text_config = MolmoAct2TextConfig()

    >>> # Initializing a MolmoAct2Config
    >>> configuration = MolmoAct2Config(
    >>>     vit_config=vit_config,
    >>>     adapter_config=adapter_config,
    >>>     text_config=text_config,
    >>>     image_start_token_id=151936,
    >>>     image_end_token_id=151937,
    >>>     image_patch_id=151938,
    >>>     image_col_id=151939,
    >>>     low_res_image_start_token_id=151940,
    >>>     image_low_res_id=151942,
    >>>     frame_start_token_id=151943,
    >>>     frame_end_token_id=151944,
    >>> )

    >>> # Initializing a model
    >>> model = MolmoAct2ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "molmoact2"
    sub_configs = {
        "text_config": MolmoAct2TextConfig,
        "vit_config": MolmoAct2VitConfig,
        "adapter_config": MolmoAct2AdapterConfig,
        "action_expert_config": MolmoAct2ActionExpertConfig,
    }

    def __init__(
        self,
        vit_config: MolmoAct2VitConfig = None,
        adapter_config: MolmoAct2AdapterConfig = None,
        text_config: MolmoAct2TextConfig = None,
        action_expert_config: MolmoAct2ActionExpertConfig = None,
        image_start_token_id: int = None,
        low_res_image_start_token_id: int = None,
        image_end_token_id: int = None,
        image_low_res_id: int = None,
        image_patch_id: int = None,
        image_col_id: int = None,
        frame_start_token_id: int = None,
        frame_end_token_id: int = None,
        use_frame_special_tokens: bool = True,
        initializer_range: float = 0.02,
        add_action_expert: bool = True,
        max_action_dim: int = 7,
        action_horizon: int = 16,
        n_obs_steps: int = 1,
        action_format: str = "continuous",
        state_format: str = "discrete",
        action_expert_condition_source: str = "kv_cache",
        action_expert_layer_mode: str = "per_layer",
        flow_matching_num_steps: int = 10,
        flow_matching_cutoff: float = 1.0,
        flow_matching_time_offset: float = 0.001,
        flow_matching_time_scale: float = 0.999,
        flow_matching_beta_alpha: float = 1.0,
        flow_matching_beta_beta: float = 1.5,
        mask_action_dim_padding: bool = True,
        enable_depth_reasoning: bool = False,
        depth_mode: int = 2,
        num_depth_codes: int = 100,
        action_expert_depth_gate: bool = False,
        action_expert_depth_gate_per_layer: bool = False,
        action_expert_depth_gate_init_bias: float = -4.0,
        action_output_token_id: int = None,
        action_start_token_id: int = None,
        action_end_token_id: int = None,
        action_token_start_id: int = None,
        num_action_tokens: int = 0,
        depth_output_token_id: int = None,
        depth_start_token_id: int = None,
        depth_end_token_id: int = None,
        depth_token_start_id: int = None,
        num_depth_tokens: int = 0,
        state_start_token_id: int = None,
        state_end_token_id: int = None,
        state_token_start_id: int = None,
        num_state_tokens: int = 0,
        add_setup_tokens: bool = True,
        add_control_tokens: bool = True,
        norm_stats_filename: str = "norm_stats.json",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Each sub-config may arrive as None (use defaults), a plain dict
        # (deserialized JSON), or an already-built config object.
        self.vit_config = self._build_subconfig(vit_config, MolmoAct2VitConfig)
        self.adapter_config = self._build_subconfig(adapter_config, MolmoAct2AdapterConfig)
        self.text_config = self._build_subconfig(text_config, MolmoAct2TextConfig)
        self.add_action_expert = bool(add_action_expert)
        if not self.add_action_expert:
            self.action_expert_config = None
        elif action_expert_config is None:
            # Default expert mirrors the text model's depth so per-layer
            # conditioning lines up one-to-one with the VLM layers.
            self.action_expert_config = MolmoAct2ActionExpertConfig(
                max_horizon=action_horizon,
                max_action_dim=max_action_dim,
                num_layers=self.text_config.num_hidden_layers,
            )
        elif isinstance(action_expert_config, dict):
            self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
        else:
            self.action_expert_config = action_expert_config
        if self.add_action_expert:
            self._validate_release_action_config(
                action_expert_config=self.action_expert_config,
                action_expert_condition_source=action_expert_condition_source,
                action_expert_layer_mode=action_expert_layer_mode,
                state_format=state_format,
            )
        # Special-token ids used to splice image features into the text stream.
        self.image_start_token_id = image_start_token_id
        self.low_res_image_start_token_id = low_res_image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.image_low_res_id = image_low_res_id
        # High-res patches reuse the generic patch id (intentional alias).
        self.image_high_res_id = image_patch_id
        self.image_patch_id = image_patch_id
        self.image_col_id = image_col_id
        self.frame_start_token_id = frame_start_token_id
        self.frame_end_token_id = frame_end_token_id
        self.use_frame_special_tokens = use_frame_special_tokens
        self.initializer_range = initializer_range
        # Action-expert / flow-matching hyperparameters.
        self.max_action_dim = max_action_dim
        self.action_horizon = action_horizon
        self.n_obs_steps = n_obs_steps
        self.action_format = action_format
        self.state_format = state_format
        self.action_expert_condition_source = action_expert_condition_source
        self.action_expert_layer_mode = action_expert_layer_mode
        self.flow_matching_num_steps = flow_matching_num_steps
        self.flow_matching_cutoff = flow_matching_cutoff
        self.flow_matching_time_offset = flow_matching_time_offset
        self.flow_matching_time_scale = flow_matching_time_scale
        self.flow_matching_beta_alpha = flow_matching_beta_alpha
        self.flow_matching_beta_beta = flow_matching_beta_beta
        self.mask_action_dim_padding = mask_action_dim_padding
        # Depth-token reasoning options.
        self.enable_depth_reasoning = enable_depth_reasoning
        self.depth_mode = depth_mode
        self.num_depth_codes = num_depth_codes
        self.action_expert_depth_gate = action_expert_depth_gate
        self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
        self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
        # Action / depth / state special-token vocab layout.
        self.action_output_token_id = action_output_token_id
        self.action_start_token_id = action_start_token_id
        self.action_end_token_id = action_end_token_id
        self.action_token_start_id = action_token_start_id
        self.num_action_tokens = num_action_tokens
        self.depth_output_token_id = depth_output_token_id
        self.depth_start_token_id = depth_start_token_id
        self.depth_end_token_id = depth_end_token_id
        self.depth_token_start_id = depth_token_start_id
        self.num_depth_tokens = num_depth_tokens
        self.state_start_token_id = state_start_token_id
        self.state_end_token_id = state_end_token_id
        self.state_token_start_id = state_token_start_id
        self.num_state_tokens = num_state_tokens
        self.add_setup_tokens = add_setup_tokens
        self.add_control_tokens = add_control_tokens
        self.norm_stats_filename = norm_stats_filename

    @staticmethod
    def _build_subconfig(config, config_cls):
        """Coerce a sub-config given as None / dict / config object into `config_cls`."""
        if config is None:
            return config_cls()
        if isinstance(config, dict):
            return config_cls(**config)
        return config

    @staticmethod
    def _validate_release_action_config(
        *,
        action_expert_config: MolmoAct2ActionExpertConfig,
        action_expert_condition_source: str,
        action_expert_layer_mode: str,
        state_format: str,
    ) -> None:
        """Reject action-expert settings not supported by this HF export."""
        if action_expert_config.implementation != "new":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert.implementation='new'."
            )
        if action_expert_condition_source != "kv_cache":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert_condition_source='kv_cache'."
            )
        if action_expert_layer_mode != "per_layer":
            raise ValueError(
                "MolmoAct2 HF export supports only action_expert_layer_mode='per_layer'."
            )
        if state_format != "discrete":
            raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")

    @property
    def image_num_patch(self):
        assert self.vit_config is not None
        return self.vit_config.image_num_patch

    @property
    def num_attention_heads(self):
        return self.text_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        return self.text_config.num_key_value_heads

    @property
    def head_dim(self):
        return self.text_config.head_dim

    @property
    def num_hidden_layers(self):
        return self.text_config.num_hidden_layers

    @property
    def hidden_size(self):
        return self.text_config.hidden_size

    @property
    def vocab_size(self):
        return self.text_config.vocab_size

    @property
    def max_position_embeddings(self):
        return self.text_config.max_position_embeddings
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# Register every MolmoAct2 configuration class so AutoConfig can resolve them
# when the checkpoint is loaded with trust_remote_code=True.
for _config_cls in (
    MolmoAct2VitConfig,
    MolmoAct2AdapterConfig,
    MolmoAct2TextConfig,
    MolmoAct2ActionExpertConfig,
    MolmoAct2Config,
):
    _config_cls.register_for_auto_class()
del _config_cls
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151645,
|
| 3 |
+
"eos_token_id": 151645,
|
| 4 |
+
"pad_token_id": 151643,
|
| 5 |
+
"transformers_version": "5.3.0"
|
| 6 |
+
}
|
image_processing_molmoact2.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processor class for MolmoAct2"""
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms
|
| 7 |
+
|
| 8 |
+
from transformers.image_utils import (
|
| 9 |
+
IMAGENET_STANDARD_MEAN,
|
| 10 |
+
IMAGENET_STANDARD_STD,
|
| 11 |
+
ImageInput,
|
| 12 |
+
PILImageResampling,
|
| 13 |
+
make_flat_list_of_images,
|
| 14 |
+
valid_images,
|
| 15 |
+
to_numpy_array,
|
| 16 |
+
)
|
| 17 |
+
from transformers.image_transforms import convert_to_rgb
|
| 18 |
+
from transformers.processing_utils import ImagesKwargs
|
| 19 |
+
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 22 |
+
from transformers.utils import TensorType, logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def normalize_image(
    image: np.ndarray,
    image_mean: list[float],
    image_std: list[float],
) -> np.ndarray:
    """Normalize a float HWC image with per-channel mean/std.

    A fast path handles the common SigLIP-style (0.5, 0.5, 0.5) mean and std,
    which reduces to ``2 * image - 1``.

    Fix: the general path previously mutated the caller's array in place via
    ``-=`` / ``/=``; it now allocates a new array like the fast path does.

    Args:
        image: Float image of shape ``[h, w, 3]`` with values in [0, 1]
            (trailing axis is channels).
        image_mean: Per-channel means.
        image_std: Per-channel standard deviations.

    Returns:
        The normalized image as a new array; the input is never modified.
    """
    if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
        return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
    mean = np.array(image_mean, dtype=np.float32)[None, None, :]
    std = np.array(image_std, dtype=np.float32)[None, None, :]
    # Binary (non in-place) ops allocate a fresh result array.
    return (image - mean) / std
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resize_image(
    image: np.ndarray,
    desired_output_size: list[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an HWC image and rescale it to float32 values in [0, 1].

    Accepts either a floating point image (assumed to be in [0, 1]) or a
    uint8 image in [0, 255]; the output is always float32 in [0, 1].

    Fix: the two dtype branches previously duplicated the identical
    ``torchvision.transforms.Resize`` call; only the clip range actually
    differs, so the resize is now shared.

    Args:
        image: ``[h, w, 3]`` numpy image, float or uint8.
        desired_output_size: Target ``[h, w]``.
        resample: Interpolation mode passed through to torchvision.

    Returns:
        ``[h', w', 3]`` float32 numpy image scaled to [0, 1].
    """
    # torchvision operates on CHW tensors.
    image = torch.permute(torch.from_numpy(image), [2, 0, 1])
    dtype = image.dtype
    if torch.is_floating_point(image):
        in_min = 0.0
        in_max = 1.0
    else:
        assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
        in_min = 0.0
        in_max = 255.0

    resized = torchvision.transforms.Resize(
        desired_output_size,
        resample,
        antialias=False,
    )(image)
    # Clamp interpolation overshoot back into the valid range, then rescale.
    resized = torch.clip(resized, in_min, in_max).to(dtype)
    resized = resized.to(torch.float32)
    resized = (resized - in_min) / (in_max - in_min)

    # Back to HWC numpy for the rest of the pipeline.
    resized = torch.permute(resized, [1, 2, 0]).numpy()

    return resized
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def select_tiling(h, w, patch_size, max_num_crops):
    """Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
    tilings = []
    for i in range(1, max_num_crops + 1):
        for j in range(1, max_num_crops + 1):
            if i * j <= max_num_crops:
                tilings.append((i, j))
    # sort so argmin and argmax favour smaller tilings in the event of a tie
    tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]

    # How much we would need to scale the image to fit exactly in each tiling
    original_size = np.array([h, w], dtype=np.float32)  # [2]

    # The original size can be zero in rare cases if the image is smaller than the margin
    # In those cases letting the scale become infinite means the tiling is based on the
    # other side, or falls back to the smallest tiling
    with np.errstate(divide='ignore'):
        # BUG FIX: a stray trailing comma previously made this a 1-tuple, which
        # numpy silently promoted to an array with an extra leading axis.
        required_scale_d = candidate_resolutions.astype(np.float32) / original_size
        required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]
    if np.all(required_scale < 1):
        # We are forced to downscale, so try to minimize the amount of downscaling
        ix = np.argmax(required_scale)
    else:
        # Pick the resolution that required the least upscaling so that it most closely fits the image
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def build_resized_image(
    image: np.ndarray,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Resize and normalize `image` to the base (global) input size.

    Returns the normalized image as a [1, h, w, 3] array, together with a
    [h_patches, w_patches] grid of sequential patch indices.
    """
    normalized = normalize_image(
        resize_image(image, base_image_input_size, resample),
        image_mean,
        image_std,
    )
    if normalized.ndim == 3:
        normalized = np.expand_dims(normalized, 0)
    patches_h = base_image_input_size[0] // image_patch_size
    patches_w = base_image_input_size[1] // image_patch_size
    index_grid = np.arange(patches_w * patches_h).reshape([patches_h, patches_w])
    return normalized, index_grid
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_overlapping_crops(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Decompose an image into a set of overlapping crops

    Fix: removed dead duplicate assignments of ``original_image_h/w`` and
    ``crop_size`` (they were computed twice with identical values).

    :return crop_arr: [n_crops, h, w, 3] The crops
    :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
        the crops were extracted from, what patch in `crop_arr` it corresponds to
    """
    original_image_h, original_image_w = image.shape[:2]
    crop_size = base_image_input_size[0]
    assert base_image_input_size[0] == base_image_input_size[1]

    left_margin, right_margin = overlap_margins
    total_margin_pixels = image_patch_size * (right_margin + left_margin)  # pixels removed per dim
    crop_patches = base_image_input_size[0] // image_patch_size  # patches per crop dim
    crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
    crop_window_size = crop_window_patches * image_patch_size
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size

    # Decide how to tile the image, to account for the overlap margins we compute the tiling
    # as if we had an image without the margins and were using a crop size without the margins
    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src = resize_image(
        image,
        [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
        resample,
    )
    src = normalize_image(src, image_mean, image_std)

    # Now we have to split the image into crops, and track what patches came from
    # where in `patch_idx_arr`
    n_crops = tiling[0] * tiling[1]
    crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
    on_crop = 0
    for i in range(tiling[0]):
        # Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
        # which results in overlapping crop windows
        y0 = i*crop_window_size
        for j in range(tiling[1]):
            x0 = j*crop_window_size
            crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size]
            patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w)
            patch_idx += on_crop * crop_patch_h * crop_patch_w

            # Mask out idx that are in the overlap region
            if i != 0:
                patch_idx[:left_margin, :] = -1
            if j != 0:
                patch_idx[:, :left_margin] = -1
            if i != tiling[0]-1:
                patch_idx[-right_margin:, :] = -1
            if j != tiling[1]-1:
                patch_idx[:, -right_margin:] = -1
            patch_idx_arr[on_crop] = patch_idx
            on_crop += 1

    # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
    # so it is ordered left-to-right order
    patch_idx_arr = np.reshape(
        patch_idx_arr,
        [tiling[0], tiling[1], crop_patch_h, crop_patch_w]
    )
    patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
    patch_idx_arr = np.reshape(patch_idx_arr, [-1])

    # Now get the parts not in the overlap region, so it should map each patch in `src`
    # to the correct patch it should come from in `crop_arr`
    patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
        src.shape[0]//image_patch_size,
        src.shape[1]//image_patch_size,
    )
    return crop_arr, patch_idx_arr
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]

    A 3D input [n_images, h, w] is treated as single-channel and yields
    [n_images, n_patches, patch_size * patch_size].
    """
    if array.ndim == 3:
        n_images, height, width = array.shape
        rows = height // patch_size
        cols = width // patch_size
        patched = array.reshape(n_images, rows, patch_size, cols, patch_size)
        patched = patched.transpose(0, 1, 3, 2, 4)
        return patched.reshape(n_images, rows * cols, patch_size * patch_size)

    n_images, height, width, channels = array.shape
    rows = height // patch_size
    cols = width // patch_size
    patched = array.reshape(n_images, rows, patch_size, cols, patch_size, channels)
    patched = patched.transpose(0, 1, 3, 2, 4, 5)
    return patched.reshape(n_images, rows * cols, patch_size * patch_size * channels)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2D patch-index grid into pooling windows.

    Pads `idx_arr` (centered, with -1 as the fill value) so its dimensions are
    divisible by the pooling size, then returns a [h, w, pool_h*pool_w] array
    where each entry lists the patch indices pooled into that output cell
    (-1 marks padding positions).
    """
    h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
    w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
    idx_arr = np.pad(
        idx_arr,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode='constant',
        constant_values=-1,
    )
    # Equivalent to einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)"),
    # written in plain numpy so this helper has no einops dependency.
    out_h = idx_arr.shape[0] // pool_h
    out_w = idx_arr.shape[1] // pool_w
    idx_arr = idx_arr.reshape(out_h, pool_h, out_w, pool_w)
    idx_arr = idx_arr.transpose(0, 2, 1, 3)
    return idx_arr.reshape(out_h, out_w, pool_h * pool_w)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def image_to_patches_and_grids(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: list[int],
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
    crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert one image into ViT-ready patches plus pooling bookkeeping.

    :return image_grids, the shape of each (low-res, high-res) image after pooling
    :return crops, the image crops to processes with the ViT
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
        patches in `crops` to pool for that token, masked with -1
    """
    if isinstance(base_image_input_size, int):
        # Treat a scalar size as a square input.
        base_image_input_size = (base_image_input_size, base_image_input_size)

    base_image_input_d = image_patch_size
    pooling_w = image_pooling_w
    pooling_h = image_pooling_h
    crop_patch_w = base_image_input_size[1] // base_image_input_d
    crop_patch_h = base_image_input_size[0] // base_image_input_d

    if crop_mode == "resize":
        # Global resized image only: no high-res crops are produced.
        resized, resize_idx = build_resized_image(
            image,
            base_image_input_size,
            resample,
            image_mean,
            image_std,
            image_patch_size,
        )
        resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
        resized_h, resized_w = resize_idx.shape[:2]
        resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
        # High-res grid dims are zero since this mode has no crops.
        image_grid = [np.array([resized_h, resized_w, 0, 0])]
        return (
            np.stack(image_grid, 0),
            batch_pixels_to_patches(resized, image_patch_size),
            resize_idx,
        )

    if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
        raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")

    # High-res pathway: overlapping crops plus a global resized image.
    crop_arr, patch_idx_arr = build_overlapping_crops(
        image,
        max_crops,
        overlap_margins,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
    h, w = pooling_idx.shape[:2]
    pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])

    # Finally do the same for the global image
    resized, resize_idx = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    # Prepend the global image ahead of the crops.
    crop_arr = np.concatenate([resized, crop_arr], 0)

    resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
    resized_h, resized_w = resize_idx.shape[:2]
    resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w])

    # Global image goes first, so the order of patches in previous crops gets increased
    # (-1 padding entries must stay -1 and are therefore excluded from the shift).
    pooling_idx = np.where(
        pooling_idx >= 0,
        pooling_idx + crop_patch_h*crop_patch_w,
        -1
    )
    pooling_idx = np.concatenate([resize_idx, pooling_idx])
    # [low-res h, low-res w, high-res h, high-res w] after pooling.
    image_grid = [np.array([resized_h, resized_w, h, w])]

    return (
        np.stack(image_grid, 0),
        batch_pixels_to_patches(crop_arr, image_patch_size),
        pooling_idx
    )
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
    """Optional image-preprocessing overrides accepted by the MolmoAct2 processor.

    Each key mirrors the identically named `MolmoAct2ImageProcessor` argument.
    """
    max_crops: Optional[int]
    overlap_margins: Optional[list[int]]
    crop_mode: Optional[str]
    patch_size: Optional[int]
    pooling_size: Optional[list[int]]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
| 357 |
+
r"""
|
| 358 |
+
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
|
| 362 |
+
Size of the image after resizing.
|
| 363 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 364 |
+
Resampling filter to use when resizing the image.
|
| 365 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 366 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 367 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 368 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 369 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 370 |
+
Whether to convert the image to RGB.
|
| 371 |
+
max_crops (`int`, *optional*, defaults to `8`):
|
| 372 |
+
Maximum number of crops to use per image.
|
| 373 |
+
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
| 374 |
+
Overlap margins to use.
|
| 375 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 376 |
+
The spatial patch size of the vision encoder.
|
| 377 |
+
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
| 378 |
+
The pooling size of the vision adapter.
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
| 382 |
+
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
size: Optional[dict[str, int]] = None,
|
| 386 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 387 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 388 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 389 |
+
do_convert_rgb: bool = True,
|
| 390 |
+
max_crops: int = 8,
|
| 391 |
+
overlap_margins: list[int] = [4, 4],
|
| 392 |
+
crop_mode: str = "overlap-and-resize-c2",
|
| 393 |
+
patch_size: int = 14,
|
| 394 |
+
pooling_size: list[int] = [2, 2],
|
| 395 |
+
**kwargs,
|
| 396 |
+
) -> None:
|
| 397 |
+
super().__init__(**kwargs)
|
| 398 |
+
size = size if size is not None else {"height": 378, "width": 378}
|
| 399 |
+
size = get_size_dict(size, default_to_square=True)
|
| 400 |
+
self.size = size
|
| 401 |
+
|
| 402 |
+
self.resample = resample
|
| 403 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 404 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 405 |
+
self.do_convert_rgb = do_convert_rgb
|
| 406 |
+
|
| 407 |
+
self.max_crops = max_crops
|
| 408 |
+
self.overlap_margins = overlap_margins
|
| 409 |
+
self.crop_mode = crop_mode
|
| 410 |
+
self.patch_size = patch_size
|
| 411 |
+
self.pooling_size = pooling_size
|
| 412 |
+
|
| 413 |
+
def preprocess(
    self,
    images: ImageInput,
    size: Optional[dict[str, int]] = None,
    resample: Optional[PILImageResampling] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_convert_rgb: Optional[bool] = None,
    max_crops: Optional[int] = None,
    overlap_margins: Optional[list[int]] = None,
    crop_mode: Optional[str] = None,
    patch_size: Optional[int] = None,
    pooling_size: Optional[list[int]] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> BatchFeature:
    """
    Args:
        images (`ImageInput`):
            Image to preprocess.
        size (`dict[str, int]`, *optional*, defaults to `self.size`):
            Size of the image after resizing. Must contain `"height"` and `"width"` keys.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter to use when resizing the image.
        image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization.
        image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        max_crops (`int`, *optional*, defaults to `self.max_crops`):
            Maximum number of crops to use per image.
        overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
            Overlap margins to use.
        crop_mode (`str`, *optional*, defaults to `self.crop_mode`):
            Cropping strategy identifier.
        patch_size (`int`, *optional*, defaults to `self.patch_size`):
            The spatial patch size of the vision encoder.
        pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
            The pooling size of the vision adapter.
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
            - Unset: Return a list of `np.ndarray`.
            - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
            - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
            - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.

    Returns:
        A `BatchFeature` containing the following keys:
        - `pixel_values`: The preprocessed images.
        - `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
        - `image_grids`: The image grids.
        - `image_num_crops`: The number of crops for each image.

    Raises:
        ValueError: If `size` lacks `"height"`/`"width"` keys, or the images are of
            an unsupported type.
    """
    if size is not None:
        if "height" not in size or "width" not in size:
            raise ValueError("size must contain 'height' and 'width' keys.")
    else:
        size = {**self.size}

    base_image_input_size = [size["height"], size["width"]]

    # Fall back to instance defaults only when an argument was not supplied.
    # Using `is not None` (rather than `or`) keeps explicit falsy overrides
    # working, e.g. `do_convert_rgb=False` or `resample=PILImageResampling.NEAREST`
    # (whose enum value is 0) would otherwise be silently replaced by the defaults.
    resample = resample if resample is not None else self.resample
    image_mean = image_mean if image_mean is not None else self.image_mean
    image_std = image_std if image_std is not None else self.image_std
    do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

    max_crops = max_crops if max_crops is not None else self.max_crops
    overlap_margins = overlap_margins if overlap_margins is not None else self.overlap_margins
    crop_mode = crop_mode if crop_mode is not None else self.crop_mode
    patch_size = patch_size if patch_size is not None else self.patch_size
    pooling_size = pooling_size if pooling_size is not None else self.pooling_size

    image_pooling_h, image_pooling_w = pooling_size

    if images is not None:
        images = self.fetch_images(images)
        images = make_flat_list_of_images(images)

    if images is not None and not valid_images(images):
        raise ValueError(
            "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
            "torch.Tensor, tf.Tensor or jax.ndarray."
        )

    if do_convert_rgb:
        images = [convert_to_rgb(image) for image in images]

    # All transformations expect numpy arrays.
    images = [to_numpy_array(image) for image in images]

    data = {}
    if images is not None:
        batch_grids = []
        batch_crops = []
        batch_pooled_patches_idx = []
        batch_num_crops = []

        for image in images:
            image_grid, crops, pooled_idx = image_to_patches_and_grids(
                image,
                max_crops,
                overlap_margins,
                base_image_input_size,
                resample,
                image_mean,
                image_std,
                patch_size,
                image_pooling_w,
                image_pooling_h,
                crop_mode,
            )
            batch_grids.append(image_grid)
            batch_crops.append(crops)
            batch_pooled_patches_idx.append(pooled_idx)
            # First crops dim is the per-image crop count.
            batch_num_crops.append(crops.shape[0])

        # Concatenate along the crop axis so downstream code can split per image
        # using `image_num_crops`.
        pixel_values = np.concatenate(batch_crops, 0)
        image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
        image_grids = np.concatenate(batch_grids, 0)
        image_num_crops = np.array(batch_num_crops)

        data.update(
            pixel_values=pixel_values,
            image_token_pooling=image_token_pooling,
            image_grids=image_grids,
            image_num_crops=image_num_crops,
        )

    return BatchFeature(data, tensor_type=return_tensors)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
# Register this custom processor class so that Auto* loading with
# `trust_remote_code=True` can resolve it from the checkpoint repository.
MolmoAct2ImageProcessor.register_for_auto_class()
|
inference.py
ADDED
|
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference utilities for MolmoAct2"""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Iterable, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from transformers.cache_utils import Cache
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class _ActionFlowInputs:
    """Bundle of tensors consumed by the action-expert flow loop."""

    # Action trajectory tensor being iteratively refined by the flow loop.
    trajectory: torch.Tensor
    # VLM-derived conditioning context. Project type: exposes `kv_contexts`,
    # `cross_mask`, `self_mask`, `valid_action`, and `rope_cache` (see
    # `ActionCudaGraphManager.can_use_action_flow`).
    context: Any
    # Per-step modulation objects; each exposes `conditioning`,
    # `block_modulations`, and `final_modulation` — presumably adaLN-style
    # conditioning tensors (TODO confirm against the action expert).
    modulations: Sequence[Any]
    # Optional mask of padded action dimensions.
    action_dim_is_pad: Optional[torch.Tensor]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class _ActionFlowCudaGraph:
    """A captured CUDA graph for the action flow loop plus its static buffers."""

    # Shape/dtype/device signature the graph was captured for; a mismatch
    # forces re-capture (see `ActionCudaGraphManager.run_action_flow`).
    key: Tuple[Any, ...]
    # The captured graph itself; replayed on each call.
    graph: torch.cuda.CUDAGraph
    # Static input buffers that are copied into before each replay.
    static_inputs: _ActionFlowInputs
    # Static output buffer written during replay; cloned before returning.
    output: torch.Tensor
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class _DepthDecodeCudaGraphLayerStage:
    """Static tensors produced by one layer's pre-attention captured segment."""

    # Residual-stream input to the layer (pre-norm).
    residual: torch.Tensor
    # Projected, normalized, RoPE-rotated attention inputs for this layer.
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class _DepthDecodeCudaGraphPostStage:
    """Captured post-attention segment (attn output proj + MLP [+ next pre])."""

    graph: torch.cuda.CUDAGraph
    # Static buffer the eagerly-computed attention context is copied into
    # before replaying `graph`.
    attn_context: torch.Tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
class _DepthDecodeCudaGraph:
    """Full set of captured graphs and static buffers for one-token depth decode."""

    # Signature (model dims, device, dtype, cache width) the graphs were built for.
    cache_key: Tuple[Any, ...]
    # Embedding + layer-0 pre-attention segment.
    pre_graph: torch.cuda.CUDAGraph
    # Static input token buffer (shape (1, 1), long).
    token_ids: torch.Tensor
    # Static RoPE cos/sin buffers refreshed per step from the rotary cache.
    cos: torch.Tensor
    sin: torch.Tensor
    # Precomputed arange used to slice cache positions without new allocations.
    positions: torch.Tensor
    # Per-layer static pre-attention outputs (residual/q/k/v).
    stages: Sequence[_DepthDecodeCudaGraphLayerStage]
    # Per-layer captured post-attention segments; the last one also applies
    # the final layer norm.
    post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
    # Static final hidden-state buffer written by the last post graph.
    output: torch.Tensor
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
class _DepthDecodeCudaGraphSpec:
    """Cached, model-derived facts about whether/how depth decode can be graphed."""

    # True only for the supported configuration (pre-norm, default RoPE, SDPA).
    eligible: bool
    # Model-architecture portion of the graph cache key.
    cache_key_prefix: Tuple[Any, ...]
    num_hidden_layers: int
    head_dim: int
    num_attention_heads: int
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
    """Return the cache's current sequence length as a plain int (0 if no cache)."""
    if past_key_values is None:
        return 0
    length = past_key_values.get_seq_length()
    # Some cache implementations report length as a 0-d tensor.
    return int(length.item() if torch.is_tensor(length) else length)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
    """Return the cache's maximum capacity as a plain int, or -1 if no cache."""
    if past_key_values is None:
        return -1
    capacity = past_key_values.get_max_cache_shape()
    # Normalize tensor-valued capacities to a Python int.
    return int(capacity.item() if torch.is_tensor(capacity) else capacity)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _iter_cache_key_values(
    past_key_values: Cache,
) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
    """Yield per-layer ``(keys, values)`` pairs from either cache flavor.

    Supports both modern caches exposing a ``layers`` attribute (with ``keys``
    / ``values`` per layer) and legacy tuple-style caches iterated directly.
    """
    layers = getattr(past_key_values, "layers", None)
    if layers is None:
        # Legacy path: the cache iterates as (key, value) tuples.
        for entry in past_key_values:
            yield entry[0], entry[1]
    else:
        for layer in layers:
            yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class _DepthDecodeStaticLayerCache:
|
| 95 |
+
is_compileable = False
|
| 96 |
+
is_sliding = False
|
| 97 |
+
|
| 98 |
+
def __init__(self, max_cache_len: int) -> None:
|
| 99 |
+
self.max_cache_len = int(max_cache_len)
|
| 100 |
+
self.cumulative_length = 0
|
| 101 |
+
self.keys: Optional[torch.Tensor] = None
|
| 102 |
+
self.values: Optional[torch.Tensor] = None
|
| 103 |
+
|
| 104 |
+
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
| 105 |
+
bsz, n_heads = key_states.shape[:2]
|
| 106 |
+
self.keys = torch.empty(
|
| 107 |
+
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
| 108 |
+
dtype=key_states.dtype,
|
| 109 |
+
device=key_states.device,
|
| 110 |
+
)
|
| 111 |
+
self.values = torch.empty(
|
| 112 |
+
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
| 113 |
+
dtype=value_states.dtype,
|
| 114 |
+
device=value_states.device,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def update(
|
| 118 |
+
self,
|
| 119 |
+
key_states: torch.Tensor,
|
| 120 |
+
value_states: torch.Tensor,
|
| 121 |
+
*args,
|
| 122 |
+
**kwargs,
|
| 123 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 124 |
+
if self.keys is None:
|
| 125 |
+
self._allocate(key_states, value_states)
|
| 126 |
+
start = self.cumulative_length
|
| 127 |
+
end = start + key_states.shape[-2]
|
| 128 |
+
if end > self.max_cache_len:
|
| 129 |
+
raise RuntimeError(
|
| 130 |
+
f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
|
| 131 |
+
)
|
| 132 |
+
self.keys[:, :, start:end, :].copy_(key_states)
|
| 133 |
+
self.values[:, :, start:end, :].copy_(value_states)
|
| 134 |
+
self.cumulative_length = end
|
| 135 |
+
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
| 136 |
+
|
| 137 |
+
def get_seq_length(self) -> int:
|
| 138 |
+
return self.cumulative_length
|
| 139 |
+
|
| 140 |
+
def get_max_cache_shape(self) -> int:
|
| 141 |
+
return -1
|
| 142 |
+
|
| 143 |
+
def reset(self) -> None:
|
| 144 |
+
self.cumulative_length = 0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class _DepthDecodeStaticCache(Cache):
    """Whole-model static cache: one `_DepthDecodeStaticLayerCache` per decoder layer."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
        # NOTE(review): callers pass `model.config.text_config` here, and this
        # resolves the text config again — presumably get_text_config() on a
        # text config returns itself; confirm.
        text_config = config.get_text_config(decoder=True)
        super().__init__(
            layers=[
                _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
                for _ in range(text_config.num_hidden_layers)
            ]
        )

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # All layers advance in lockstep, so any single layer is representative.
        return self.layers[layer_idx].get_seq_length()

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self) -> None:
        """Logically empty every layer cache (buffers are reused in place)."""
        for layer in self.layers:
            layer.reset()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ActionCudaGraphManager:
    """Captures and replays a CUDA graph for the action-expert flow loop.

    The flow loop is captured once per input signature; subsequent calls copy
    fresh inputs into the captured static buffers and replay the graph.
    """

    def __init__(self, model: Any) -> None:
        self.model = model
        self.enabled = True
        # Single cached graph; rebuilt whenever the input signature changes.
        self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None

    def set_enabled(self, enabled: bool) -> None:
        self.enabled = bool(enabled)

    def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
        """Return True when graph capture/replay is safe for these inputs.

        Requires eval mode (capture would bake in dropout noise / autograd
        state) and every participating tensor to live on CUDA.
        """
        action_model = self.model
        if not self.enabled:
            return False
        if action_model.training or action_model._require_action_expert().training:
            return False
        if inputs.trajectory.device.type != "cuda":
            return False

        def all_on_cuda():
            # Enumerate every tensor that participates in the captured region.
            yield inputs.trajectory
            for k, v in inputs.context.kv_contexts:
                yield k
                yield v
            for t in (
                inputs.context.cross_mask,
                inputs.context.self_mask,
                inputs.context.valid_action,
                inputs.action_dim_is_pad,
            ):
                if t is not None:
                    yield t
            if inputs.context.rope_cache is not None:
                yield from inputs.context.rope_cache
            for step in inputs.modulations:
                yield step.conditioning
                for block_modulation in step.block_modulations:
                    yield from block_modulation
                yield from step.final_modulation

        return all(t.device.type == "cuda" for t in all_on_cuda())

    def run_action_flow(
        self,
        inputs: _ActionFlowInputs,
        steps: int,
        run_loop,
    ) -> torch.Tensor:
        """Run `run_loop(inputs, steps)` via a (possibly freshly captured) CUDA graph.

        Args:
            inputs: Flow-loop inputs; copied into static buffers before replay.
            steps: Number of flow integration steps (part of the graph key).
            run_loop: Callable executing the actual loop; only invoked during
                capture, afterwards the recorded graph is replayed.

        Returns:
            A clone of the graph's static output tensor (cloned so callers can
            hold it across subsequent replays).
        """
        key = _cuda_graph_key(inputs, steps)
        cache = self.action_flow_graph
        if cache is None or cache.key != key:
            # Signature changed (or first call): capture a new graph over
            # cloned static input buffers.
            static_inputs = _clone_static_inputs(inputs)
            graph, output = _capture_cuda_graph(
                lambda: run_loop(static_inputs, steps),
                inputs.trajectory.device,
                # Warmup iterations mutate the static trajectory; restore the
                # caller's values before the actual capture.
                after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
            )
            cache = _ActionFlowCudaGraph(
                key=key,
                graph=graph,
                static_inputs=static_inputs,
                output=output,
            )
            self.action_flow_graph = cache
        else:
            # Reuse the captured graph: refresh its static inputs in place.
            _copy_inputs_(cache.static_inputs, inputs)

        cache.graph.replay()
        return cache.output.clone()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class DepthDecodeCudaGraphManager:
|
| 239 |
+
def __init__(self, model: Any) -> None:
|
| 240 |
+
self.model = model
|
| 241 |
+
self.backbone = model.model
|
| 242 |
+
self.enabled = True
|
| 243 |
+
self.graph: Optional[_DepthDecodeCudaGraph] = None
|
| 244 |
+
self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
|
| 245 |
+
|
| 246 |
+
def set_enabled(self, enabled: bool) -> None:
|
| 247 |
+
self.enabled = bool(enabled)
|
| 248 |
+
|
| 249 |
+
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
| 250 |
+
return _DepthDecodeStaticCache(
|
| 251 |
+
config=self.model.config.text_config,
|
| 252 |
+
max_cache_len=max_cache_len,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
| 256 |
+
static = self.graph_spec
|
| 257 |
+
if static is None:
|
| 258 |
+
cfg = self.backbone.transformer.config
|
| 259 |
+
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
| 260 |
+
static = _DepthDecodeCudaGraphSpec(
|
| 261 |
+
eligible=(
|
| 262 |
+
not cfg.norm_after
|
| 263 |
+
and cfg.rope_scaling_layers is None
|
| 264 |
+
and getattr(rotary_emb, "rope_type", None) == "default"
|
| 265 |
+
and cfg._attn_implementation == "sdpa"
|
| 266 |
+
),
|
| 267 |
+
cache_key_prefix=(
|
| 268 |
+
cfg.hidden_size,
|
| 269 |
+
cfg.num_attention_heads,
|
| 270 |
+
cfg.num_key_value_heads,
|
| 271 |
+
cfg.head_dim,
|
| 272 |
+
cfg.num_hidden_layers,
|
| 273 |
+
cfg.use_qk_norm,
|
| 274 |
+
cfg.qk_norm_type,
|
| 275 |
+
cfg._attn_implementation,
|
| 276 |
+
),
|
| 277 |
+
num_hidden_layers=cfg.num_hidden_layers,
|
| 278 |
+
head_dim=cfg.head_dim,
|
| 279 |
+
num_attention_heads=cfg.num_attention_heads,
|
| 280 |
+
)
|
| 281 |
+
self.graph_spec = static
|
| 282 |
+
return static
|
| 283 |
+
|
| 284 |
+
def can_use(
|
| 285 |
+
self,
|
| 286 |
+
next_input_ids: torch.Tensor,
|
| 287 |
+
*,
|
| 288 |
+
past_key_values: Cache,
|
| 289 |
+
attention_bias: torch.Tensor,
|
| 290 |
+
) -> bool:
|
| 291 |
+
if (
|
| 292 |
+
not self.enabled
|
| 293 |
+
or self.model.training
|
| 294 |
+
or self.backbone.transformer.training
|
| 295 |
+
):
|
| 296 |
+
return False
|
| 297 |
+
if next_input_ids.device.type != "cuda":
|
| 298 |
+
return False
|
| 299 |
+
if (
|
| 300 |
+
next_input_ids.ndim != 2
|
| 301 |
+
or next_input_ids.shape[0] != 1
|
| 302 |
+
or next_input_ids.shape[1] != 1
|
| 303 |
+
):
|
| 304 |
+
return False
|
| 305 |
+
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
| 306 |
+
return False
|
| 307 |
+
if (
|
| 308 |
+
not torch.is_tensor(attention_bias)
|
| 309 |
+
or attention_bias.device != next_input_ids.device
|
| 310 |
+
):
|
| 311 |
+
return False
|
| 312 |
+
return self._depth_decode_spec().eligible
|
| 313 |
+
|
| 314 |
+
def _depth_decode_key(
|
| 315 |
+
self,
|
| 316 |
+
next_input_ids: torch.Tensor,
|
| 317 |
+
attention_bias: torch.Tensor,
|
| 318 |
+
) -> Tuple[Any, ...]:
|
| 319 |
+
device = next_input_ids.device
|
| 320 |
+
return (
|
| 321 |
+
self._depth_decode_spec().cache_key_prefix,
|
| 322 |
+
device.type,
|
| 323 |
+
device.index,
|
| 324 |
+
self.model.lm_head.weight.dtype,
|
| 325 |
+
attention_bias.shape[-1],
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
def _select_depth_decode_rope(
|
| 329 |
+
self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
|
| 330 |
+
) -> None:
|
| 331 |
+
emb = self.backbone.transformer.rotary_emb
|
| 332 |
+
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
| 333 |
+
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
| 334 |
+
|
| 335 |
+
def _depth_decode_pre_layer(
|
| 336 |
+
self,
|
| 337 |
+
layer_idx: int,
|
| 338 |
+
hidden_states: torch.Tensor,
|
| 339 |
+
cos: torch.Tensor,
|
| 340 |
+
sin: torch.Tensor,
|
| 341 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 342 |
+
block = self.backbone.transformer.blocks[layer_idx]
|
| 343 |
+
attention = block.self_attn
|
| 344 |
+
residual = hidden_states
|
| 345 |
+
hidden_states = block.attn_norm(hidden_states)
|
| 346 |
+
|
| 347 |
+
input_shape = hidden_states.shape[:-1]
|
| 348 |
+
hidden_shape = (*input_shape, -1, attention.head_dim)
|
| 349 |
+
qkv = attention.att_proj(hidden_states)
|
| 350 |
+
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
| 351 |
+
value_states = value_states.view(hidden_shape)
|
| 352 |
+
|
| 353 |
+
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
| 354 |
+
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
| 355 |
+
|
| 356 |
+
if apply_qk_norm and not norm_after_view:
|
| 357 |
+
query_states = attention.q_norm(query_states)
|
| 358 |
+
key_states = attention.k_norm(key_states)
|
| 359 |
+
|
| 360 |
+
query_states = query_states.view(hidden_shape)
|
| 361 |
+
key_states = key_states.view(hidden_shape)
|
| 362 |
+
|
| 363 |
+
if norm_after_view:
|
| 364 |
+
query_states = attention.q_norm(query_states)
|
| 365 |
+
key_states = attention.k_norm(key_states)
|
| 366 |
+
|
| 367 |
+
query_states = query_states.transpose(1, 2)
|
| 368 |
+
key_states = key_states.transpose(1, 2)
|
| 369 |
+
value_states = value_states.transpose(1, 2)
|
| 370 |
+
query_states, key_states = _apply_rotary_pos_emb(
|
| 371 |
+
query_states, key_states, cos, sin
|
| 372 |
+
)
|
| 373 |
+
return residual, query_states, key_states, value_states
|
| 374 |
+
|
| 375 |
+
def _depth_decode_pre0(
|
| 376 |
+
self,
|
| 377 |
+
token_ids: torch.Tensor,
|
| 378 |
+
cos: torch.Tensor,
|
| 379 |
+
sin: torch.Tensor,
|
| 380 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 381 |
+
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
| 382 |
+
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
| 383 |
+
|
| 384 |
+
def _depth_decode_post_layer(
|
| 385 |
+
self,
|
| 386 |
+
layer_idx: int,
|
| 387 |
+
residual: torch.Tensor,
|
| 388 |
+
attn_context: torch.Tensor,
|
| 389 |
+
) -> torch.Tensor:
|
| 390 |
+
block = self.backbone.transformer.blocks[layer_idx]
|
| 391 |
+
attention = block.self_attn
|
| 392 |
+
input_shape = residual.shape[:-1]
|
| 393 |
+
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
| 394 |
+
attn_output = attention.attn_out(attn_output)
|
| 395 |
+
hidden_states = residual + block.dropout(attn_output)
|
| 396 |
+
|
| 397 |
+
residual = hidden_states
|
| 398 |
+
hidden_states = block.ff_norm(hidden_states)
|
| 399 |
+
hidden_states = block.mlp(hidden_states)
|
| 400 |
+
hidden_states = residual + block.dropout(hidden_states)
|
| 401 |
+
return hidden_states
|
| 402 |
+
|
| 403 |
+
def _depth_decode_post_and_pre_next(
|
| 404 |
+
self,
|
| 405 |
+
layer_idx: int,
|
| 406 |
+
residual: torch.Tensor,
|
| 407 |
+
attn_context: torch.Tensor,
|
| 408 |
+
cos: torch.Tensor,
|
| 409 |
+
sin: torch.Tensor,
|
| 410 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 411 |
+
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
| 412 |
+
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
| 413 |
+
|
| 414 |
+
def _depth_decode_last_post(
|
| 415 |
+
self,
|
| 416 |
+
layer_idx: int,
|
| 417 |
+
residual: torch.Tensor,
|
| 418 |
+
attn_context: torch.Tensor,
|
| 419 |
+
) -> torch.Tensor:
|
| 420 |
+
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
| 421 |
+
return self.backbone.transformer.ln_f(hidden_states)
|
| 422 |
+
|
| 423 |
+
def _build_depth_decode_graph(
|
| 424 |
+
self,
|
| 425 |
+
next_input_ids: torch.Tensor,
|
| 426 |
+
*,
|
| 427 |
+
past_length: int,
|
| 428 |
+
attention_bias: torch.Tensor,
|
| 429 |
+
) -> _DepthDecodeCudaGraph:
|
| 430 |
+
text_config = self.backbone.transformer.config
|
| 431 |
+
device = next_input_ids.device
|
| 432 |
+
dtype = self.model.lm_head.weight.dtype
|
| 433 |
+
static = self._depth_decode_spec()
|
| 434 |
+
num_layers = static.num_hidden_layers
|
| 435 |
+
head_dim = static.head_dim
|
| 436 |
+
max_cache_len = int(attention_bias.shape[-1])
|
| 437 |
+
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
| 438 |
+
self.backbone.transformer.prepare_rope_cache(
|
| 439 |
+
device=device, max_seq_len=max_rope_len
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
| 443 |
+
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
| 444 |
+
sin = torch.empty_like(cos)
|
| 445 |
+
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
| 446 |
+
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
| 447 |
+
|
| 448 |
+
token_ids.copy_(next_input_ids)
|
| 449 |
+
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
| 450 |
+
|
| 451 |
+
pre_graph, pre_output = _capture_cuda_graph(
|
| 452 |
+
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
| 453 |
+
device,
|
| 454 |
+
)
|
| 455 |
+
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
| 456 |
+
post_graphs = []
|
| 457 |
+
for layer_idx in range(num_layers - 1):
|
| 458 |
+
stage = stages[-1]
|
| 459 |
+
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
| 460 |
+
graph, output = _capture_cuda_graph(
|
| 461 |
+
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
| 462 |
+
self._depth_decode_post_and_pre_next(
|
| 463 |
+
layer_idx,
|
| 464 |
+
stage.residual,
|
| 465 |
+
attn_context,
|
| 466 |
+
cos,
|
| 467 |
+
sin,
|
| 468 |
+
)
|
| 469 |
+
),
|
| 470 |
+
device,
|
| 471 |
+
)
|
| 472 |
+
post_graphs.append(
|
| 473 |
+
_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
|
| 474 |
+
)
|
| 475 |
+
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
| 476 |
+
|
| 477 |
+
last_stage = stages[-1]
|
| 478 |
+
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
| 479 |
+
last_graph, last_output = _capture_cuda_graph(
|
| 480 |
+
lambda: self._depth_decode_last_post(
|
| 481 |
+
num_layers - 1,
|
| 482 |
+
last_stage.residual,
|
| 483 |
+
last_attn_context,
|
| 484 |
+
),
|
| 485 |
+
device,
|
| 486 |
+
)
|
| 487 |
+
post_graphs.append(
|
| 488 |
+
_DepthDecodeCudaGraphPostStage(
|
| 489 |
+
graph=last_graph, attn_context=last_attn_context
|
| 490 |
+
)
|
| 491 |
+
)
|
| 492 |
+
return _DepthDecodeCudaGraph(
|
| 493 |
+
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
| 494 |
+
pre_graph=pre_graph,
|
| 495 |
+
token_ids=token_ids,
|
| 496 |
+
cos=cos,
|
| 497 |
+
sin=sin,
|
| 498 |
+
positions=positions,
|
| 499 |
+
stages=tuple(stages),
|
| 500 |
+
post_graphs=tuple(post_graphs),
|
| 501 |
+
output=last_output,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
def _get_depth_decode_graph(
|
| 505 |
+
self,
|
| 506 |
+
next_input_ids: torch.Tensor,
|
| 507 |
+
*,
|
| 508 |
+
past_length: int,
|
| 509 |
+
attention_bias: torch.Tensor,
|
| 510 |
+
) -> _DepthDecodeCudaGraph:
|
| 511 |
+
key = self._depth_decode_key(next_input_ids, attention_bias)
|
| 512 |
+
decode_graph = self.graph
|
| 513 |
+
if decode_graph is None or decode_graph.cache_key != key:
|
| 514 |
+
decode_graph = self._build_depth_decode_graph(
|
| 515 |
+
next_input_ids,
|
| 516 |
+
past_length=past_length,
|
| 517 |
+
attention_bias=attention_bias,
|
| 518 |
+
)
|
| 519 |
+
self.graph = decode_graph
|
| 520 |
+
else:
|
| 521 |
+
decode_graph.token_ids.copy_(next_input_ids)
|
| 522 |
+
self._select_depth_decode_rope(
|
| 523 |
+
decode_graph.cos, decode_graph.sin, past_length=past_length
|
| 524 |
+
)
|
| 525 |
+
return decode_graph
|
| 526 |
+
|
| 527 |
+
def _run_depth_decode_attention_core(
|
| 528 |
+
self,
|
| 529 |
+
layer_idx: int,
|
| 530 |
+
stage: _DepthDecodeCudaGraphLayerStage,
|
| 531 |
+
*,
|
| 532 |
+
past_key_values: Cache,
|
| 533 |
+
attention_bias: torch.Tensor,
|
| 534 |
+
cache_position: torch.Tensor,
|
| 535 |
+
cos: torch.Tensor,
|
| 536 |
+
sin: torch.Tensor,
|
| 537 |
+
) -> torch.Tensor:
|
| 538 |
+
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
| 539 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 540 |
+
key_states, value_states = past_key_values.update(
|
| 541 |
+
stage.key,
|
| 542 |
+
stage.value,
|
| 543 |
+
layer_idx,
|
| 544 |
+
cache_kwargs,
|
| 545 |
+
)
|
| 546 |
+
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
| 547 |
+
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
| 548 |
+
attn_output = F.scaled_dot_product_attention(
|
| 549 |
+
stage.query,
|
| 550 |
+
key_states,
|
| 551 |
+
value_states,
|
| 552 |
+
attn_mask=attention_bias,
|
| 553 |
+
dropout_p=0.0,
|
| 554 |
+
is_causal=False,
|
| 555 |
+
)
|
| 556 |
+
return attn_output.transpose(1, 2)
|
| 557 |
+
|
| 558 |
+
def run(
    self,
    next_input_ids: torch.Tensor,
    *,
    past_key_values: Cache,
    attention_bias: torch.Tensor,
    past_length: int,
) -> Tuple[torch.Tensor, Cache]:
    """Decode one token by replaying the captured depth-decode CUDA graphs.

    Replays the pre-attention graph once, runs the eager attention core
    for every layer (so the KV cache can grow), then replays each layer's
    post-attention graph.  Returns the decode graph's static output tensor
    together with the mutated ``past_key_values`` cache.
    """
    # The new token sits at 0-based position ``past_length``; its query
    # attends to key positions [0, past_length].
    end = past_length + 1
    # Fetch (or build) a graph whose static buffers match this token/bias
    # signature; this also copies next_input_ids into the static buffer
    # and selects the RoPE cos/sin rows for the current position.
    decode_graph = self._get_depth_decode_graph(
        next_input_ids,
        past_length=past_length,
        attention_bias=attention_bias,
    )
    cache_position = decode_graph.positions[past_length:end]
    # Slice the full bias down to the single query row and the valid keys.
    attention_bias_q = attention_bias[:, :, past_length:end, :end]

    # NOTE(review): pre_graph presumably fills stages[*].query/key/value
    # consumed below — confirm against the capture code.
    decode_graph.pre_graph.replay()

    for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
        # Attention runs eagerly: the cache update cannot be captured.
        attn_context = self._run_depth_decode_attention_core(
            layer_idx,
            decode_graph.stages[layer_idx],
            past_key_values=past_key_values,
            attention_bias=attention_bias_q,
            cache_position=cache_position,
            cos=decode_graph.cos,
            sin=decode_graph.sin,
        )
        # Feed the eager result into the graph's static input buffer,
        # then replay the captured post-attention segment for this layer.
        post_graph.attn_context.copy_(attn_context)
        post_graph.graph.replay()

    return decode_graph.output, past_key_values
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def _cuda_graph_tensor_signature(
|
| 594 |
+
tensor: Optional[torch.Tensor],
|
| 595 |
+
) -> Optional[Tuple[Any, ...]]:
|
| 596 |
+
if tensor is None:
|
| 597 |
+
return None
|
| 598 |
+
return (
|
| 599 |
+
tuple(tensor.shape),
|
| 600 |
+
tuple(tensor.stride()),
|
| 601 |
+
str(tensor.dtype),
|
| 602 |
+
str(tensor.device),
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
    """Layout signature of an action-flow context for CUDA-graph keying."""
    sig = _cuda_graph_tensor_signature
    kv_sigs = tuple((sig(key), sig(value)) for key, value in context.kv_contexts)
    if context.rope_cache is None:
        rope_sig = None
    else:
        rope_sig = tuple(sig(entry) for entry in context.rope_cache)
    return (
        kv_sigs,
        sig(context.cross_mask),
        sig(context.self_mask),
        sig(context.valid_action),
        rope_sig,
    )
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
    """Layout signature of per-step modulation tensors for the graph key."""
    sig = _cuda_graph_tensor_signature
    step_sigs = []
    for step in modulations:
        block_sigs = tuple(
            tuple(sig(t) for t in block_modulation)
            for block_modulation in step.block_modulations
        )
        final_sig = tuple(sig(t) for t in step.final_modulation)
        step_sigs.append((sig(step.conditioning), block_sigs, final_sig))
    return tuple(step_sigs)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
    """Complete cache key for a captured action-flow CUDA graph.

    Combines the layout signatures of every input tensor with the number
    of flow-matching steps, so a graph is only reused when both match.
    """
    sig = _cuda_graph_tensor_signature
    parts = [
        sig(inputs.trajectory),
        _cuda_graph_context_signature(inputs.context),
        _cuda_graph_modulation_signature(inputs.modulations),
        sig(inputs.action_dim_is_pad),
        int(steps),
    ]
    return tuple(parts)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 646 |
+
if tensor is None:
|
| 647 |
+
return None
|
| 648 |
+
static = torch.empty_strided(
|
| 649 |
+
tuple(tensor.shape),
|
| 650 |
+
tuple(tensor.stride()),
|
| 651 |
+
device=tensor.device,
|
| 652 |
+
dtype=tensor.dtype,
|
| 653 |
+
)
|
| 654 |
+
static.copy_(tensor)
|
| 655 |
+
return static
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def _clone_static_context(context: Any) -> Any:
    """Rebuild an action-flow context with every tensor cloned to a static buffer."""
    cloned_rope = None
    if context.rope_cache is not None:
        cloned_rope = tuple(_clone_static_tensor(t) for t in context.rope_cache)
    cloned_kv = tuple(
        (_clone_static_tensor(key), _clone_static_tensor(value))
        for key, value in context.kv_contexts
    )
    return context.__class__(
        kv_contexts=cloned_kv,
        cross_mask=_clone_static_tensor(context.cross_mask),
        self_mask=_clone_static_tensor(context.self_mask),
        valid_action=_clone_static_tensor(context.valid_action),
        rope_cache=cloned_rope,
    )
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
    """Clone each modulation step's tensors into static CUDA-graph buffers."""
    cloned_steps = []
    for step in modulations:
        cloned_blocks = tuple(
            tuple(_clone_static_tensor(t) for t in block_modulation)
            for block_modulation in step.block_modulations
        )
        cloned_final = tuple(
            _clone_static_tensor(t) for t in step.final_modulation
        )
        cloned_steps.append(
            step.__class__(
                conditioning=_clone_static_tensor(step.conditioning),
                block_modulations=cloned_blocks,
                final_modulation=cloned_final,
            )
        )
    return tuple(cloned_steps)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
    """Produce a static-buffer copy of all action-flow inputs for graph capture."""
    static_trajectory = _clone_static_tensor(inputs.trajectory)
    static_context = _clone_static_context(inputs.context)
    static_modulations = _clone_static_modulations(inputs.modulations)
    static_pad = _clone_static_tensor(inputs.action_dim_is_pad)
    return _ActionFlowInputs(
        trajectory=static_trajectory,
        context=static_context,
        modulations=static_modulations,
        action_dim_is_pad=static_pad,
    )
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def _copy_context_(dst: Any, src: Any) -> None:
|
| 700 |
+
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
| 701 |
+
dst_k.copy_(src_k)
|
| 702 |
+
dst_v.copy_(src_v)
|
| 703 |
+
if src.cross_mask is not None:
|
| 704 |
+
dst.cross_mask.copy_(src.cross_mask)
|
| 705 |
+
if src.self_mask is not None:
|
| 706 |
+
dst.self_mask.copy_(src.self_mask)
|
| 707 |
+
if src.valid_action is not None:
|
| 708 |
+
dst.valid_action.copy_(src.valid_action)
|
| 709 |
+
if src.rope_cache is not None:
|
| 710 |
+
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
| 711 |
+
dst_tensor.copy_(src_tensor)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
    """Copy *src* action-flow inputs into *dst*'s static buffers in place.

    Used to refresh a captured CUDA graph's static input tensors before
    replay.  Note that ``modulations`` are not copied by this function.
    """
    dst.trajectory.copy_(src.trajectory)
    _copy_context_(dst.context, src.context)
    # action_dim_is_pad is optional; leave dst untouched when src has none.
    if src.action_dim_is_pad is not None:
        dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 722 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 723 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 724 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def _apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors.

    ``cos``/``sin`` gain an axis at ``unsqueeze_dim`` (default 1) so they
    broadcast across the heads dimension of the q/k tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    return rotate(q), rotate(k)
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 742 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 743 |
+
if n_rep == 1:
|
| 744 |
+
return hidden_states
|
| 745 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 746 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 747 |
+
)
|
| 748 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def _capture_cuda_graph(
    fn,
    device: torch.device,
    *,
    after_warmup=None,
) -> Tuple[torch.cuda.CUDAGraph, Any]:
    """Warm up ``fn`` on a side stream, then capture it into a CUDA graph.

    Args:
        fn: Zero-argument callable.  It must read/write only static tensors
            so that replaying the captured graph remains valid.
        device: CUDA device to warm up and capture on.
        after_warmup: Optional callback run after the warm-up completes,
            e.g. to reset state the warm-up invocation mutated.

    Returns:
        The captured ``torch.cuda.CUDAGraph`` and the output of the capture
        invocation of ``fn`` (its static output tensors).
    """
    # Run fn once on a dedicated stream so lazy one-time work (kernel
    # autotuning, cudnn plan building, allocator pool growth) happens
    # outside the capture.
    warmup_stream = torch.cuda.Stream(device=device)
    warmup_stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(warmup_stream):
        fn()
    # Join the warm-up stream back before capturing on the current stream.
    torch.cuda.current_stream(device).wait_stream(warmup_stream)
    if after_warmup is not None:
        after_warmup()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        output = fn()
    return graph, output
|
model-00001-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08663a88bc5c1f6c1cf8534a8cdf7971eb2fd66979ac42d38752d5209b971e6b
|
| 3 |
+
size 4929809880
|
model-00002-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9282d1390ca3fe81f31eaa4a925bc881ecb7c56c315eb8d3d1e2f6616cba7af9
|
| 3 |
+
size 4844690992
|
model-00003-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f35fdd96126d04975d1feae1b715a875d03ad9d07c7de25fa91a6573cceb1e7
|
| 3 |
+
size 4844691024
|
model-00004-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e1b546c537a6eb6b5e281e5b8d6819ffe8ee8f88c55da5025e1e6ee2721cc907
|
| 3 |
+
size 4998106920
|
model-00005-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a925159b5e363f72007c38a2c4c2fc7cdac6e8b6ae990ddaa51f3abe526beb77
|
| 3 |
+
size 2345090936
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_molmoact2.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
norm_stats.json
ADDED
|
@@ -0,0 +1,1739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"format": "molmoact2_norm_stats.v1",
|
| 3 |
+
"norm_mode": "q01_q99",
|
| 4 |
+
"metadata_by_tag": {
|
| 5 |
+
"franka_molmoact": {
|
| 6 |
+
"action_key": "action.del_ee_action",
|
| 7 |
+
"state_key": "observation.state",
|
| 8 |
+
"camera_keys": [
|
| 9 |
+
"observation.images.primary",
|
| 10 |
+
"observation.images.secondary"
|
| 11 |
+
],
|
| 12 |
+
"normalize_gripper": false,
|
| 13 |
+
"action_horizon": 10,
|
| 14 |
+
"n_action_steps": 10,
|
| 15 |
+
"setup_type": "single franka robotic arm in molmoact",
|
| 16 |
+
"control_mode": "delta end-effector pose",
|
| 17 |
+
"action_stats": {
|
| 18 |
+
"min": [
|
| 19 |
+
-0.07434078305959702,
|
| 20 |
+
-0.07339745759963989,
|
| 21 |
+
-0.06539416313171387,
|
| 22 |
+
-0.1688285619020462,
|
| 23 |
+
-0.10289879888296127,
|
| 24 |
+
-0.2667275667190552,
|
| 25 |
+
0.0
|
| 26 |
+
],
|
| 27 |
+
"max": [
|
| 28 |
+
0.06042003631591797,
|
| 29 |
+
0.09417290985584259,
|
| 30 |
+
0.07019275426864624,
|
| 31 |
+
0.2616892158985138,
|
| 32 |
+
0.11751057207584381,
|
| 33 |
+
0.16968433558940887,
|
| 34 |
+
1.0
|
| 35 |
+
],
|
| 36 |
+
"mean": [
|
| 37 |
+
0.0005923698136522352,
|
| 38 |
+
0.000245022598131832,
|
| 39 |
+
-4.604843771714063e-05,
|
| 40 |
+
0.00022562421486693225,
|
| 41 |
+
-0.0005166618849942836,
|
| 42 |
+
-0.0002193919428051152,
|
| 43 |
+
0.557619424517478
|
| 44 |
+
],
|
| 45 |
+
"std": [
|
| 46 |
+
0.005274540883280089,
|
| 47 |
+
0.007662320435387572,
|
| 48 |
+
0.006516662891595147,
|
| 49 |
+
0.013564563259375743,
|
| 50 |
+
0.011179215063905077,
|
| 51 |
+
0.015195633113705318,
|
| 52 |
+
0.49666890583432166
|
| 53 |
+
],
|
| 54 |
+
"count": [
|
| 55 |
+
1482599.0
|
| 56 |
+
],
|
| 57 |
+
"q01": [
|
| 58 |
+
-0.011251236153566059,
|
| 59 |
+
-0.014918113203115847,
|
| 60 |
+
-0.011753186696798671,
|
| 61 |
+
-0.02785908205770074,
|
| 62 |
+
-0.025679744407356857,
|
| 63 |
+
-0.03279371599275369,
|
| 64 |
+
3.7096921558780464e-05
|
| 65 |
+
],
|
| 66 |
+
"q10": [
|
| 67 |
+
-0.005157558671709432,
|
| 68 |
+
-0.007627389508279324,
|
| 69 |
+
-0.006774633516067545,
|
| 70 |
+
-0.013867640389035468,
|
| 71 |
+
-0.01314412247667587,
|
| 72 |
+
-0.016390209597024155,
|
| 73 |
+
0.012615970474925397
|
| 74 |
+
],
|
| 75 |
+
"q50": [
|
| 76 |
+
0.00047587496704591567,
|
| 77 |
+
-5.756867525949417e-05,
|
| 78 |
+
-0.0004126053693703461,
|
| 79 |
+
0.00010505655624582394,
|
| 80 |
+
6.41251115100509e-05,
|
| 81 |
+
-0.00028445035571581385,
|
| 82 |
+
0.6884608295876035
|
| 83 |
+
],
|
| 84 |
+
"q90": [
|
| 85 |
+
0.006329953764757322,
|
| 86 |
+
0.008685542226677301,
|
| 87 |
+
0.008054135204293992,
|
| 88 |
+
0.01397800046720906,
|
| 89 |
+
0.010417940135392682,
|
| 90 |
+
0.016052135597642597,
|
| 91 |
+
0.9523435823006378
|
| 92 |
+
],
|
| 93 |
+
"q99": [
|
| 94 |
+
0.01117553552099105,
|
| 95 |
+
0.016859041899882184,
|
| 96 |
+
0.015590574732817865,
|
| 97 |
+
0.029286192888436802,
|
| 98 |
+
0.023178728984454205,
|
| 99 |
+
0.031348125431223534,
|
| 100 |
+
0.9523741512556912
|
| 101 |
+
],
|
| 102 |
+
"names": [
|
| 103 |
+
"x",
|
| 104 |
+
"y",
|
| 105 |
+
"z",
|
| 106 |
+
"rx",
|
| 107 |
+
"ry",
|
| 108 |
+
"rz",
|
| 109 |
+
"gripper"
|
| 110 |
+
],
|
| 111 |
+
"mask": [
|
| 112 |
+
true,
|
| 113 |
+
true,
|
| 114 |
+
true,
|
| 115 |
+
true,
|
| 116 |
+
true,
|
| 117 |
+
true,
|
| 118 |
+
false
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
"state_stats": {
|
| 122 |
+
"min": [
|
| 123 |
+
-0.26428329944610596,
|
| 124 |
+
-0.6690786480903625,
|
| 125 |
+
-0.11737073212862015,
|
| 126 |
+
-3.141592264175415,
|
| 127 |
+
-1.4651211500167847,
|
| 128 |
+
-2.9524343013763428,
|
| 129 |
+
-0.014517154544591904
|
| 130 |
+
],
|
| 131 |
+
"max": [
|
| 132 |
+
0.8226616978645325,
|
| 133 |
+
0.7252005338668823,
|
| 134 |
+
0.9137527346611023,
|
| 135 |
+
3.141592264175415,
|
| 136 |
+
1.3202887773513794,
|
| 137 |
+
1.35053551197052,
|
| 138 |
+
1.0004678964614868
|
| 139 |
+
],
|
| 140 |
+
"mean": [
|
| 141 |
+
0.524974531443455,
|
| 142 |
+
-0.009077995640631331,
|
| 143 |
+
0.37626277677807307,
|
| 144 |
+
-1.1230985182050761,
|
| 145 |
+
-0.15037831955429493,
|
| 146 |
+
-0.8360877101239638,
|
| 147 |
+
0.4828000792054066
|
| 148 |
+
],
|
| 149 |
+
"std": [
|
| 150 |
+
0.108734747557998,
|
| 151 |
+
0.19003219833018514,
|
| 152 |
+
0.13128520583933115,
|
| 153 |
+
2.684752550519181,
|
| 154 |
+
0.2757065272207643,
|
| 155 |
+
0.41552838131031417,
|
| 156 |
+
0.44164051887464084
|
| 157 |
+
],
|
| 158 |
+
"count": [
|
| 159 |
+
1482599.0
|
| 160 |
+
],
|
| 161 |
+
"q01": [
|
| 162 |
+
0.3835934249761093,
|
| 163 |
+
-0.16975945635008685,
|
| 164 |
+
0.26948875059068883,
|
| 165 |
+
-3.1154851500608602,
|
| 166 |
+
-0.3736599588300681,
|
| 167 |
+
-1.1556019879922976,
|
| 168 |
+
-0.010862173339892917
|
| 169 |
+
],
|
| 170 |
+
"q10": [
|
| 171 |
+
0.4262234785864046,
|
| 172 |
+
-0.14936716135148972,
|
| 173 |
+
0.28255691882796946,
|
| 174 |
+
-2.886297848869603,
|
| 175 |
+
-0.33795875325689667,
|
| 176 |
+
-1.111211596662902,
|
| 177 |
+
-0.010235056183139664
|
| 178 |
+
],
|
| 179 |
+
"q50": [
|
| 180 |
+
0.526748316869364,
|
| 181 |
+
-0.018291417601920976,
|
| 182 |
+
0.37479483390901625,
|
| 183 |
+
-1.3501052773711595,
|
| 184 |
+
-0.15482441515331163,
|
| 185 |
+
-0.8602460018117026,
|
| 186 |
+
0.5755928858365665
|
| 187 |
+
],
|
| 188 |
+
"q90": [
|
| 189 |
+
0.6151325476883934,
|
| 190 |
+
0.13938787377123313,
|
| 191 |
+
0.47141213932315906,
|
| 192 |
+
0.9833625012077128,
|
| 193 |
+
0.044654971996918924,
|
| 194 |
+
-0.5307531964313489,
|
| 195 |
+
0.8665980232471624
|
| 196 |
+
],
|
| 197 |
+
"q99": [
|
| 198 |
+
0.6251391330078412,
|
| 199 |
+
0.16582465215431033,
|
| 200 |
+
0.5049115299029577,
|
| 201 |
+
1.351274663819693,
|
| 202 |
+
0.09734563994616442,
|
| 203 |
+
-0.4656737437923123,
|
| 204 |
+
0.8676457869434169
|
| 205 |
+
],
|
| 206 |
+
"names": [
|
| 207 |
+
"x",
|
| 208 |
+
"y",
|
| 209 |
+
"z",
|
| 210 |
+
"rx",
|
| 211 |
+
"ry",
|
| 212 |
+
"rz",
|
| 213 |
+
"gripper"
|
| 214 |
+
],
|
| 215 |
+
"mask": [
|
| 216 |
+
true,
|
| 217 |
+
true,
|
| 218 |
+
true,
|
| 219 |
+
true,
|
| 220 |
+
true,
|
| 221 |
+
true,
|
| 222 |
+
false
|
| 223 |
+
]
|
| 224 |
+
}
|
| 225 |
+
},
|
| 226 |
+
"franka_droid": {
|
| 227 |
+
"action_key": "action",
|
| 228 |
+
"state_key": "observation.state",
|
| 229 |
+
"camera_keys": [
|
| 230 |
+
"observation.images.exterior_1_left",
|
| 231 |
+
"observation.images.exterior_2_left",
|
| 232 |
+
"observation.images.wrist_left"
|
| 233 |
+
],
|
| 234 |
+
"normalize_gripper": false,
|
| 235 |
+
"action_horizon": 15,
|
| 236 |
+
"n_action_steps": 15,
|
| 237 |
+
"setup_type": "single franka robotic arm in droid",
|
| 238 |
+
"control_mode": "absolute joint pose",
|
| 239 |
+
"action_stats": {
|
| 240 |
+
"min": [
|
| 241 |
+
-2.781099557876587,
|
| 242 |
+
-1.6407934427261353,
|
| 243 |
+
-2.7493984699249268,
|
| 244 |
+
-2.9508564472198486,
|
| 245 |
+
-2.7826988697052,
|
| 246 |
+
0.17983438074588776,
|
| 247 |
+
-2.901715040206909,
|
| 248 |
+
0.0
|
| 249 |
+
],
|
| 250 |
+
"max": [
|
| 251 |
+
2.7449073791503906,
|
| 252 |
+
1.6668277978897095,
|
| 253 |
+
2.7546653747558594,
|
| 254 |
+
-0.1936211884021759,
|
| 255 |
+
2.7786083221435547,
|
| 256 |
+
4.402013778686523,
|
| 257 |
+
2.90183162689209,
|
| 258 |
+
1.0
|
| 259 |
+
],
|
| 260 |
+
"mean": [
|
| 261 |
+
0.010418229819396566,
|
| 262 |
+
0.28233935319840636,
|
| 263 |
+
-0.015346633420959944,
|
| 264 |
+
-2.0060878874674715,
|
| 265 |
+
-0.029448930257783886,
|
| 266 |
+
2.350942437431684,
|
| 267 |
+
0.09820869537671756,
|
| 268 |
+
0.4390250813949694
|
| 269 |
+
],
|
| 270 |
+
"std": [
|
| 271 |
+
0.3170372143097277,
|
| 272 |
+
0.4863630998896905,
|
| 273 |
+
0.27477375809610444,
|
| 274 |
+
0.48806966037647037,
|
| 275 |
+
0.528105567983804,
|
| 276 |
+
0.4517944470893175,
|
| 277 |
+
0.7430287051319469,
|
| 278 |
+
0.44171628153080567
|
| 279 |
+
],
|
| 280 |
+
"count": [
|
| 281 |
+
17758044.0
|
| 282 |
+
],
|
| 283 |
+
"q01": [
|
| 284 |
+
-0.2879620949867506,
|
| 285 |
+
-0.5702219304684566,
|
| 286 |
+
-0.31101638810433413,
|
| 287 |
+
-2.5622234922052725,
|
| 288 |
+
-0.5101021838814974,
|
| 289 |
+
1.7376836093987995,
|
| 290 |
+
-0.5227783063045004,
|
| 291 |
+
8.3274762776141e-05
|
| 292 |
+
],
|
| 293 |
+
"q10": [
|
| 294 |
+
-0.2014563967491066,
|
| 295 |
+
-0.1953558605308627,
|
| 296 |
+
-0.20948523622127932,
|
| 297 |
+
-2.402722104277799,
|
| 298 |
+
-0.3766226599271436,
|
| 299 |
+
1.9723782378158212,
|
| 300 |
+
-0.35517365133256956,
|
| 301 |
+
0.006328163222238114
|
| 302 |
+
],
|
| 303 |
+
"q50": [
|
| 304 |
+
0.007162417319678469,
|
| 305 |
+
0.336456475452052,
|
| 306 |
+
-0.013974891825914252,
|
| 307 |
+
-2.008015245005848,
|
| 308 |
+
-0.025656272672692895,
|
| 309 |
+
2.3675065323600304,
|
| 310 |
+
0.09545267517159627,
|
| 311 |
+
0.4136583280016901
|
| 312 |
+
],
|
| 313 |
+
"q90": [
|
| 314 |
+
0.226670788411882,
|
| 315 |
+
0.6598602025239771,
|
| 316 |
+
0.17603345191458397,
|
| 317 |
+
-1.6119685483207011,
|
| 318 |
+
0.3092560943750579,
|
| 319 |
+
2.7008573894589345,
|
| 320 |
+
0.5521874183259908,
|
| 321 |
+
0.8737818408325101
|
| 322 |
+
],
|
| 323 |
+
"q99": [
|
| 324 |
+
0.32190872731317344,
|
| 325 |
+
0.7405054873177153,
|
| 326 |
+
0.2737893247287367,
|
| 327 |
+
-1.5075067942029405,
|
| 328 |
+
0.4329542718063284,
|
| 329 |
+
2.804162424656418,
|
| 330 |
+
0.7128911284154664,
|
| 331 |
+
0.8917437724235555
|
| 332 |
+
],
|
| 333 |
+
"names": [
|
| 334 |
+
"joint_0",
|
| 335 |
+
"joint_1",
|
| 336 |
+
"joint_2",
|
| 337 |
+
"joint_3",
|
| 338 |
+
"joint_4",
|
| 339 |
+
"joint_5",
|
| 340 |
+
"joint_6",
|
| 341 |
+
"gripper"
|
| 342 |
+
],
|
| 343 |
+
"mask": [
|
| 344 |
+
true,
|
| 345 |
+
true,
|
| 346 |
+
true,
|
| 347 |
+
true,
|
| 348 |
+
true,
|
| 349 |
+
true,
|
| 350 |
+
true,
|
| 351 |
+
false
|
| 352 |
+
]
|
| 353 |
+
},
|
| 354 |
+
"state_stats": {
|
| 355 |
+
"min": [
|
| 356 |
+
-2.6536705493927,
|
| 357 |
+
-1.6156227588653564,
|
| 358 |
+
-2.6781487464904785,
|
| 359 |
+
-2.9409868717193604,
|
| 360 |
+
-2.6705946922302246,
|
| 361 |
+
0.24893812835216522,
|
| 362 |
+
-2.757359266281128,
|
| 363 |
+
0.0
|
| 364 |
+
],
|
| 365 |
+
"max": [
|
| 366 |
+
2.6687583923339844,
|
| 367 |
+
1.5840554237365723,
|
| 368 |
+
2.666306734085083,
|
| 369 |
+
-0.29779934883117676,
|
| 370 |
+
2.6624162197113037,
|
| 371 |
+
4.272191524505615,
|
| 372 |
+
2.755643367767334,
|
| 373 |
+
1.0
|
| 374 |
+
],
|
| 375 |
+
"mean": [
|
| 376 |
+
0.011081824850861873,
|
| 377 |
+
0.27280296447760194,
|
| 378 |
+
-0.01550719225628586,
|
| 379 |
+
-2.01647228106023,
|
| 380 |
+
-0.029620826332964655,
|
| 381 |
+
2.3483866081585507,
|
| 382 |
+
0.09636965416886735,
|
| 383 |
+
0.3927326432557614
|
| 384 |
+
],
|
| 385 |
+
"std": [
|
| 386 |
+
0.31291266868924655,
|
| 387 |
+
0.4934370267472678,
|
| 388 |
+
0.2728791258795487,
|
| 389 |
+
0.48437020229024425,
|
| 390 |
+
0.521435680610052,
|
| 391 |
+
0.44821751701382595,
|
| 392 |
+
0.7352730961005634,
|
| 393 |
+
0.4070640216658998
|
| 394 |
+
],
|
| 395 |
+
"count": [
|
| 396 |
+
17758044.0
|
| 397 |
+
],
|
| 398 |
+
"q01": [
|
| 399 |
+
-0.2793009809782748,
|
| 400 |
+
-0.5873924424866738,
|
| 401 |
+
-0.3058546817065916,
|
| 402 |
+
-2.5639055042030354,
|
| 403 |
+
-0.491431808753978,
|
| 404 |
+
1.7381500993283228,
|
| 405 |
+
-0.5086147192989775,
|
| 406 |
+
1.6414552399718753e-05
|
| 407 |
+
],
|
| 408 |
+
"q10": [
|
| 409 |
+
-0.1994457930505723,
|
| 410 |
+
-0.2381088441987148,
|
| 411 |
+
-0.2103897594636481,
|
| 412 |
+
-2.421918892949847,
|
| 413 |
+
-0.3725951094142233,
|
| 414 |
+
1.961410109454104,
|
| 415 |
+
-0.35782982482940473,
|
| 416 |
+
0.005005809072924616
|
| 417 |
+
],
|
| 418 |
+
"q50": [
|
| 419 |
+
0.007891181486763803,
|
| 420 |
+
0.3376595448103942,
|
| 421 |
+
-0.014280627673021464,
|
| 422 |
+
-2.0134951539128574,
|
| 423 |
+
-0.025990006808582142,
|
| 424 |
+
2.3690656185268972,
|
| 425 |
+
0.09443906823538496,
|
| 426 |
+
0.38343357074070045
|
| 427 |
+
],
|
| 428 |
+
"q90": [
|
| 429 |
+
0.22605189533019984,
|
| 430 |
+
0.6543162155730768,
|
| 431 |
+
0.17689204963635444,
|
| 432 |
+
-1.6243810394635305,
|
| 433 |
+
0.30497772553178637,
|
| 434 |
+
2.696376125344824,
|
| 435 |
+
0.5494813775877777,
|
| 436 |
+
0.7734412581580631
|
| 437 |
+
],
|
| 438 |
+
"q99": [
|
| 439 |
+
0.3148177895778054,
|
| 440 |
+
0.7235689468221655,
|
| 441 |
+
0.2683897323238184,
|
| 442 |
+
-1.530780071911146,
|
| 443 |
+
0.415067150345451,
|
| 444 |
+
2.7863710743039887,
|
| 445 |
+
0.6952765173061115,
|
| 446 |
+
0.7968550629755542
|
| 447 |
+
],
|
| 448 |
+
"names": [
|
| 449 |
+
"joint_0",
|
| 450 |
+
"joint_1",
|
| 451 |
+
"joint_2",
|
| 452 |
+
"joint_3",
|
| 453 |
+
"joint_4",
|
| 454 |
+
"joint_5",
|
| 455 |
+
"joint_6",
|
| 456 |
+
"gripper"
|
| 457 |
+
],
|
| 458 |
+
"mask": [
|
| 459 |
+
true,
|
| 460 |
+
true,
|
| 461 |
+
true,
|
| 462 |
+
true,
|
| 463 |
+
true,
|
| 464 |
+
true,
|
| 465 |
+
true,
|
| 466 |
+
false
|
| 467 |
+
]
|
| 468 |
+
}
|
| 469 |
+
},
|
| 470 |
+
"google_robot_fractal": {
|
| 471 |
+
"action_key": "action",
|
| 472 |
+
"state_key": "observation.state",
|
| 473 |
+
"camera_keys": [
|
| 474 |
+
"observation.images.image"
|
| 475 |
+
],
|
| 476 |
+
"normalize_gripper": false,
|
| 477 |
+
"action_horizon": 3,
|
| 478 |
+
"n_action_steps": 3,
|
| 479 |
+
"setup_type": "google robot in rt_1",
|
| 480 |
+
"control_mode": "delta end-effector pose",
|
| 481 |
+
"action_stats": {
|
| 482 |
+
"min": [
|
| 483 |
+
-2.0204520225524902,
|
| 484 |
+
-5.497899532318115,
|
| 485 |
+
-2.031663417816162,
|
| 486 |
+
-1.569917917251587,
|
| 487 |
+
-1.569892168045044,
|
| 488 |
+
-1.570419430732727,
|
| 489 |
+
0.0
|
| 490 |
+
],
|
| 491 |
+
"max": [
|
| 492 |
+
2.9984593391418457,
|
| 493 |
+
22.09052848815918,
|
| 494 |
+
2.7507524490356445,
|
| 495 |
+
1.570636510848999,
|
| 496 |
+
1.5321086645126343,
|
| 497 |
+
1.5691522359848022,
|
| 498 |
+
1.0
|
| 499 |
+
],
|
| 500 |
+
"mean": [
|
| 501 |
+
0.006986742172085001,
|
| 502 |
+
0.006266400645656189,
|
| 503 |
+
-0.012625619452946994,
|
| 504 |
+
0.04333477176605177,
|
| 505 |
+
-0.005755843126369106,
|
| 506 |
+
0.0009133710921551742,
|
| 507 |
+
0.5354204546016331
|
| 508 |
+
],
|
| 509 |
+
"std": [
|
| 510 |
+
0.06943342828666754,
|
| 511 |
+
0.05987580207886052,
|
| 512 |
+
0.07384291122356837,
|
| 513 |
+
0.15697640227077467,
|
| 514 |
+
0.13192376844373777,
|
| 515 |
+
0.1463219229157086,
|
| 516 |
+
0.49874381100185294
|
| 517 |
+
],
|
| 518 |
+
"count": [
|
| 519 |
+
3786400.0
|
| 520 |
+
],
|
| 521 |
+
"q01": [
|
| 522 |
+
-0.22488493870935375,
|
| 523 |
+
-0.14842987771463928,
|
| 524 |
+
-0.23165991540148315,
|
| 525 |
+
-0.3518507387123856,
|
| 526 |
+
-0.4191961375830685,
|
| 527 |
+
-0.43642424734739155,
|
| 528 |
+
-1.000000013351432e-10
|
| 529 |
+
],
|
| 530 |
+
"q10": [
|
| 531 |
+
-0.057097137110108394,
|
| 532 |
+
-0.04180085777840345,
|
| 533 |
+
-0.08797302699742898,
|
| 534 |
+
-0.08695764133325046,
|
| 535 |
+
-0.14987822626697328,
|
| 536 |
+
-0.14407043696379337,
|
| 537 |
+
-1.000000013351432e-10
|
| 538 |
+
],
|
| 539 |
+
"q50": [
|
| 540 |
+
0.0024323156617234785,
|
| 541 |
+
0.001999621430072272,
|
| 542 |
+
-0.006186507557852898,
|
| 543 |
+
0.010844173385829027,
|
| 544 |
+
9.716094932283909e-05,
|
| 545 |
+
0.00029282634304717123,
|
| 546 |
+
0.9998131999001298
|
| 547 |
+
],
|
| 548 |
+
"q90": [
|
| 549 |
+
0.0799327921066265,
|
| 550 |
+
0.06281248479995295,
|
| 551 |
+
0.05719906967641521,
|
| 552 |
+
0.2181351081319081,
|
| 553 |
+
0.12581539646577725,
|
| 554 |
+
0.14653933152766907,
|
| 555 |
+
0.999962639980026
|
| 556 |
+
],
|
| 557 |
+
"q99": [
|
| 558 |
+
0.1780379284730618,
|
| 559 |
+
0.1492598341805028,
|
| 560 |
+
0.2184954847280796,
|
| 561 |
+
0.5894017219543457,
|
| 562 |
+
0.3527610110385077,
|
| 563 |
+
0.4478335709948289,
|
| 564 |
+
0.9999962639980026
|
| 565 |
+
],
|
| 566 |
+
"names": [
|
| 567 |
+
"x",
|
| 568 |
+
"y",
|
| 569 |
+
"z",
|
| 570 |
+
"roll",
|
| 571 |
+
"pitch",
|
| 572 |
+
"yaw",
|
| 573 |
+
"gripper"
|
| 574 |
+
],
|
| 575 |
+
"mask": [
|
| 576 |
+
true,
|
| 577 |
+
true,
|
| 578 |
+
true,
|
| 579 |
+
true,
|
| 580 |
+
true,
|
| 581 |
+
true,
|
| 582 |
+
false
|
| 583 |
+
]
|
| 584 |
+
},
|
| 585 |
+
"state_stats": {
|
| 586 |
+
"min": [
|
| 587 |
+
-0.4436439275741577,
|
| 588 |
+
-0.9970501065254211,
|
| 589 |
+
-0.006579156965017319,
|
| 590 |
+
-0.8643477559089661,
|
| 591 |
+
-0.7079970240592957,
|
| 592 |
+
-0.7688722014427185,
|
| 593 |
+
-0.4999994933605194,
|
| 594 |
+
0.0
|
| 595 |
+
],
|
| 596 |
+
"max": [
|
| 597 |
+
1.0534898042678833,
|
| 598 |
+
0.48018959164619446,
|
| 599 |
+
1.6896663904190063,
|
| 600 |
+
0.9999993443489075,
|
| 601 |
+
0.9999874830245972,
|
| 602 |
+
0.9554369449615479,
|
| 603 |
+
0.9914546012878418,
|
| 604 |
+
1.0
|
| 605 |
+
],
|
| 606 |
+
"mean": [
|
| 607 |
+
0.5582046028643476,
|
| 608 |
+
-0.08324323429555826,
|
| 609 |
+
0.7708198142579598,
|
| 610 |
+
-0.24752762586024715,
|
| 611 |
+
0.4959921774813562,
|
| 612 |
+
0.0925577145133276,
|
| 613 |
+
0.20941890216560163,
|
| 614 |
+
0.42619563761216767
|
| 615 |
+
],
|
| 616 |
+
"std": [
|
| 617 |
+
0.12440319799919354,
|
| 618 |
+
0.11571359399631491,
|
| 619 |
+
0.2458943611771509,
|
| 620 |
+
0.5132342578001884,
|
| 621 |
+
0.5223439094545202,
|
| 622 |
+
0.1666598633276366,
|
| 623 |
+
0.27617123901287927,
|
| 624 |
+
0.4538753441706389
|
| 625 |
+
],
|
| 626 |
+
"count": [
|
| 627 |
+
3786400.0
|
| 628 |
+
],
|
| 629 |
+
"q01": [
|
| 630 |
+
0.3249422830693862,
|
| 631 |
+
-0.28341992821874495,
|
| 632 |
+
0.14102827969076331,
|
| 633 |
+
-0.6864852132802142,
|
| 634 |
+
-0.6809632829655476,
|
| 635 |
+
-0.36044700054021983,
|
| 636 |
+
-0.4542378536110671,
|
| 637 |
+
-1.000000013351432e-10
|
| 638 |
+
],
|
| 639 |
+
"q10": [
|
| 640 |
+
0.42490653590113253,
|
| 641 |
+
-0.2163404740670024,
|
| 642 |
+
0.37762326560147996,
|
| 643 |
+
-0.6294334687684712,
|
| 644 |
+
-0.5920843577131312,
|
| 645 |
+
-0.09803071723264807,
|
| 646 |
+
-0.23202098126670248,
|
| 647 |
+
-1.000000013351432e-10
|
| 648 |
+
],
|
| 649 |
+
"q50": [
|
| 650 |
+
0.5389458633818717,
|
| 651 |
+
-0.10059445446247807,
|
| 652 |
+
0.8738477700690715,
|
| 653 |
+
-0.4849259061727551,
|
| 654 |
+
0.7293306254210121,
|
| 655 |
+
0.09137287071030761,
|
| 656 |
+
0.23796976550241536,
|
| 657 |
+
0.1832136750707287
|
| 658 |
+
],
|
| 659 |
+
"q90": [
|
| 660 |
+
0.7370583820538442,
|
| 661 |
+
0.08210784119164745,
|
| 662 |
+
0.9798527660285249,
|
| 663 |
+
0.7291734785677116,
|
| 664 |
+
0.84104651841686,
|
| 665 |
+
0.3032210107222038,
|
| 666 |
+
0.5373912158511455,
|
| 667 |
+
0.9999365636178622
|
| 668 |
+
],
|
| 669 |
+
"q99": [
|
| 670 |
+
0.8750117781915163,
|
| 671 |
+
0.21252014598261149,
|
| 672 |
+
1.0727446933587392,
|
| 673 |
+
0.9378297494636977,
|
| 674 |
+
0.9562844548524763,
|
| 675 |
+
0.46002622460251424,
|
| 676 |
+
0.721691133425786,
|
| 677 |
+
0.9999936563617862
|
| 678 |
+
],
|
| 679 |
+
"names": [
|
| 680 |
+
"x",
|
| 681 |
+
"y",
|
| 682 |
+
"z",
|
| 683 |
+
"rx",
|
| 684 |
+
"ry",
|
| 685 |
+
"rz",
|
| 686 |
+
"rw",
|
| 687 |
+
"gripper"
|
| 688 |
+
],
|
| 689 |
+
"mask": [
|
| 690 |
+
true,
|
| 691 |
+
true,
|
| 692 |
+
true,
|
| 693 |
+
true,
|
| 694 |
+
true,
|
| 695 |
+
true,
|
| 696 |
+
true,
|
| 697 |
+
false
|
| 698 |
+
]
|
| 699 |
+
}
|
| 700 |
+
},
|
| 701 |
+
"widowx_bridge": {
|
| 702 |
+
"action_key": "action",
|
| 703 |
+
"state_key": "observation.state",
|
| 704 |
+
"camera_keys": [
|
| 705 |
+
"observation.images.image_0",
|
| 706 |
+
"observation.images.image_1",
|
| 707 |
+
"observation.images.image_2",
|
| 708 |
+
"observation.images.image_3"
|
| 709 |
+
],
|
| 710 |
+
"normalize_gripper": false,
|
| 711 |
+
"action_horizon": 5,
|
| 712 |
+
"n_action_steps": 5,
|
| 713 |
+
"setup_type": "single widowx robotic arm in bridge",
|
| 714 |
+
"control_mode": "delta end-effector pose",
|
| 715 |
+
"action_stats": {
|
| 716 |
+
"min": [
|
| 717 |
+
-0.4007510244846344,
|
| 718 |
+
-0.13874775171279907,
|
| 719 |
+
-0.22553899884223938,
|
| 720 |
+
-3.2010786533355713,
|
| 721 |
+
-1.8618112802505493,
|
| 722 |
+
-6.279075622558594,
|
| 723 |
+
0.0
|
| 724 |
+
],
|
| 725 |
+
"max": [
|
| 726 |
+
0.41691166162490845,
|
| 727 |
+
0.25864794850349426,
|
| 728 |
+
0.21218234300613403,
|
| 729 |
+
3.122201919555664,
|
| 730 |
+
1.8618112802505493,
|
| 731 |
+
6.272472858428955,
|
| 732 |
+
1.0
|
| 733 |
+
],
|
| 734 |
+
"mean": [
|
| 735 |
+
0.00022731789976267202,
|
| 736 |
+
0.0001311203695138562,
|
| 737 |
+
-0.00012641641264803482,
|
| 738 |
+
-0.00014410962647987843,
|
| 739 |
+
-0.0003903070519037156,
|
| 740 |
+
0.00024063480455490454,
|
| 741 |
+
0.5765894392570026
|
| 742 |
+
],
|
| 743 |
+
"std": [
|
| 744 |
+
0.009782343005332487,
|
| 745 |
+
0.013714070718580267,
|
| 746 |
+
0.012687395519404626,
|
| 747 |
+
0.02848996416069207,
|
| 748 |
+
0.030552792886390234,
|
| 749 |
+
0.07751153262919225,
|
| 750 |
+
0.49409209255711634
|
| 751 |
+
],
|
| 752 |
+
"count": [
|
| 753 |
+
1893026.0
|
| 754 |
+
],
|
| 755 |
+
"q01": [
|
| 756 |
+
-0.02871995611488819,
|
| 757 |
+
-0.04170781908448411,
|
| 758 |
+
-0.02608340910386921,
|
| 759 |
+
-0.0808367313719228,
|
| 760 |
+
-0.09246813206247581,
|
| 761 |
+
-0.20693750972396757,
|
| 762 |
+
-1.000000013351432e-10
|
| 763 |
+
],
|
| 764 |
+
"q10": [
|
| 765 |
+
-0.010151055597043716,
|
| 766 |
+
-0.014922217821287087,
|
| 767 |
+
-0.01393665282931255,
|
| 768 |
+
-0.029593090264604636,
|
| 769 |
+
-0.03406380769665256,
|
| 770 |
+
-0.06413116391050117,
|
| 771 |
+
-1.000000013351432e-10
|
| 772 |
+
],
|
| 773 |
+
"q50": [
|
| 774 |
+
2.1248139354103056e-05,
|
| 775 |
+
-9.382913823534339e-06,
|
| 776 |
+
-0.0008275577521357758,
|
| 777 |
+
-0.00014731252460737677,
|
| 778 |
+
0.00047152176188271845,
|
| 779 |
+
0.0012537133528066303,
|
| 780 |
+
0.9998265319000765
|
| 781 |
+
],
|
| 782 |
+
"q90": [
|
| 783 |
+
0.011082387395765428,
|
| 784 |
+
0.015737555353724994,
|
| 785 |
+
0.016874204550636374,
|
| 786 |
+
0.02832893788750676,
|
| 787 |
+
0.0322629905973504,
|
| 788 |
+
0.06417266804375155,
|
| 789 |
+
0.9999653063800154
|
| 790 |
+
],
|
| 791 |
+
"q99": [
|
| 792 |
+
0.028291364035668985,
|
| 793 |
+
0.040898679036702676,
|
| 794 |
+
0.04018220331768194,
|
| 795 |
+
0.08177042032653538,
|
| 796 |
+
0.07759675528459531,
|
| 797 |
+
0.203201938256362,
|
| 798 |
+
0.9999965306380015
|
| 799 |
+
],
|
| 800 |
+
"names": [
|
| 801 |
+
"x",
|
| 802 |
+
"y",
|
| 803 |
+
"z",
|
| 804 |
+
"roll",
|
| 805 |
+
"pitch",
|
| 806 |
+
"yaw",
|
| 807 |
+
"gripper"
|
| 808 |
+
],
|
| 809 |
+
"mask": [
|
| 810 |
+
true,
|
| 811 |
+
true,
|
| 812 |
+
true,
|
| 813 |
+
true,
|
| 814 |
+
true,
|
| 815 |
+
true,
|
| 816 |
+
false
|
| 817 |
+
]
|
| 818 |
+
},
|
| 819 |
+
"state_stats": {
|
| 820 |
+
"min": [
|
| 821 |
+
-0.04167502000927925,
|
| 822 |
+
-0.3563207685947418,
|
| 823 |
+
-0.15537554025650024,
|
| 824 |
+
-3.141592502593994,
|
| 825 |
+
-1.4992541074752808,
|
| 826 |
+
-3.14153790473938,
|
| 827 |
+
0.0,
|
| 828 |
+
0.04637829214334488
|
| 829 |
+
],
|
| 830 |
+
"max": [
|
| 831 |
+
0.5862360596656799,
|
| 832 |
+
0.4034728705883026,
|
| 833 |
+
0.3568263053894043,
|
| 834 |
+
1.3517684936523438,
|
| 835 |
+
1.570796251296997,
|
| 836 |
+
3.141204357147217,
|
| 837 |
+
0.0,
|
| 838 |
+
1.1121242046356201
|
| 839 |
+
],
|
| 840 |
+
"mean": [
|
| 841 |
+
0.3094503633235095,
|
| 842 |
+
0.030725376723448255,
|
| 843 |
+
0.06443996750169499,
|
| 844 |
+
0.0064906683342908335,
|
| 845 |
+
-0.07720050195254197,
|
| 846 |
+
0.10766038148835028,
|
| 847 |
+
0.0,
|
| 848 |
+
0.7081244810708762
|
| 849 |
+
],
|
| 850 |
+
"std": [
|
| 851 |
+
0.06060302901710459,
|
| 852 |
+
0.0919536927343182,
|
| 853 |
+
0.05159382707079282,
|
| 854 |
+
0.1312174751351825,
|
| 855 |
+
0.16924010047039229,
|
| 856 |
+
0.5787203550709503,
|
| 857 |
+
0.0,
|
| 858 |
+
0.35365012001260804
|
| 859 |
+
],
|
| 860 |
+
"count": [
|
| 861 |
+
1893026.0
|
| 862 |
+
],
|
| 863 |
+
"q01": [
|
| 864 |
+
0.17102651970064053,
|
| 865 |
+
-0.16977934478310977,
|
| 866 |
+
-0.05565095783375642,
|
| 867 |
+
-0.3649685841887744,
|
| 868 |
+
-0.5418705685890239,
|
| 869 |
+
-1.3540046312592247,
|
| 870 |
+
0.0,
|
| 871 |
+
0.05212163980268402
|
| 872 |
+
],
|
| 873 |
+
"q10": [
|
| 874 |
+
0.234054275333357,
|
| 875 |
+
-0.08584102855009192,
|
| 876 |
+
0.007129108058706664,
|
| 877 |
+
-0.13279207613930774,
|
| 878 |
+
-0.2879179685802783,
|
| 879 |
+
-0.47590377710082316,
|
| 880 |
+
0.0,
|
| 881 |
+
0.08160105384386226
|
| 882 |
+
],
|
| 883 |
+
"q50": [
|
| 884 |
+
0.30824996150509265,
|
| 885 |
+
0.02806205006373531,
|
| 886 |
+
0.061364141277506515,
|
| 887 |
+
0.003477529234181987,
|
| 888 |
+
-0.06586482997881163,
|
| 889 |
+
0.033681061760553146,
|
| 890 |
+
0.0,
|
| 891 |
+
0.9850432498405283
|
| 892 |
+
],
|
| 893 |
+
"q90": [
|
| 894 |
+
0.3866535994382209,
|
| 895 |
+
0.15225549791502352,
|
| 896 |
+
0.1303319111924363,
|
| 897 |
+
0.14920492884702988,
|
| 898 |
+
0.11511126950562722,
|
| 899 |
+
0.8206040455663128,
|
| 900 |
+
0.0,
|
| 901 |
+
1.0013512433353218
|
| 902 |
+
],
|
| 903 |
+
"q99": [
|
| 904 |
+
0.453255677819252,
|
| 905 |
+
0.23543677111215228,
|
| 906 |
+
0.19489739182202712,
|
| 907 |
+
0.378015822982788,
|
| 908 |
+
0.27597790842706504,
|
| 909 |
+
1.8504199743270873,
|
| 910 |
+
0.0,
|
| 911 |
+
1.0106366157291133
|
| 912 |
+
],
|
| 913 |
+
"names": [
|
| 914 |
+
"x",
|
| 915 |
+
"y",
|
| 916 |
+
"z",
|
| 917 |
+
"roll",
|
| 918 |
+
"pitch",
|
| 919 |
+
"yaw",
|
| 920 |
+
"pad",
|
| 921 |
+
"gripper"
|
| 922 |
+
],
|
| 923 |
+
"mask": [
|
| 924 |
+
true,
|
| 925 |
+
true,
|
| 926 |
+
true,
|
| 927 |
+
true,
|
| 928 |
+
true,
|
| 929 |
+
true,
|
| 930 |
+
true,
|
| 931 |
+
false
|
| 932 |
+
]
|
| 933 |
+
}
|
| 934 |
+
},
|
| 935 |
+
"so100_so101_molmoact2": {
|
| 936 |
+
"action_key": "action",
|
| 937 |
+
"state_key": "observation.state",
|
| 938 |
+
"camera_keys": [],
|
| 939 |
+
"normalize_gripper": true,
|
| 940 |
+
"action_horizon": 30,
|
| 941 |
+
"n_action_steps": 30,
|
| 942 |
+
"setup_type": "single so100/so101 robotic arm in molmoact2",
|
| 943 |
+
"control_mode": "absolute joint pose",
|
| 944 |
+
"action_stats": {
|
| 945 |
+
"min": [
|
| 946 |
+
-122.607421875,
|
| 947 |
+
-270.0,
|
| 948 |
+
-269.208984375,
|
| 949 |
+
-125.771484375,
|
| 950 |
+
-269.912109375,
|
| 951 |
+
-31.57327651977539
|
| 952 |
+
],
|
| 953 |
+
"max": [
|
| 954 |
+
179.208984375,
|
| 955 |
+
219.638671875,
|
| 956 |
+
195.380859375,
|
| 957 |
+
178.9453125,
|
| 958 |
+
269.82421875,
|
| 959 |
+
119.40789031982422
|
| 960 |
+
],
|
| 961 |
+
"mean": [
|
| 962 |
+
3.343996486826433,
|
| 963 |
+
125.7905980370996,
|
| 964 |
+
120.20220128113388,
|
| 965 |
+
55.88144220174933,
|
| 966 |
+
-11.543010633027725,
|
| 967 |
+
11.25886240824774
|
| 968 |
+
],
|
| 969 |
+
"std": [
|
| 970 |
+
28.909870406169997,
|
| 971 |
+
52.25069634659296,
|
| 972 |
+
47.94432906599221,
|
| 973 |
+
36.01019142727721,
|
| 974 |
+
69.35504013212369,
|
| 975 |
+
17.116239869449775
|
| 976 |
+
],
|
| 977 |
+
"count": [
|
| 978 |
+
19619650.0
|
| 979 |
+
],
|
| 980 |
+
"q01": [
|
| 981 |
+
-42.1300246338976,
|
| 982 |
+
45.18258358164995,
|
| 983 |
+
35.40059182962813,
|
| 984 |
+
4.929781836327758,
|
| 985 |
+
-65.57568617645342,
|
| 986 |
+
-0.3016556932619033
|
| 987 |
+
],
|
| 988 |
+
"q10": [
|
| 989 |
+
-25.040070398997557,
|
| 990 |
+
68.27827215165794,
|
| 991 |
+
65.76540485606242,
|
| 992 |
+
26.58811186925123,
|
| 993 |
+
-39.81868441470048,
|
| 994 |
+
0.26123181871944706
|
| 995 |
+
],
|
| 996 |
+
"q50": [
|
| 997 |
+
3.0828094324713105,
|
| 998 |
+
124.5495736487354,
|
| 999 |
+
122.75175717637279,
|
| 1000 |
+
57.77960070056314,
|
| 1001 |
+
-11.094802886190045,
|
| 1002 |
+
4.866634607477139
|
| 1003 |
+
],
|
| 1004 |
+
"q90": [
|
| 1005 |
+
31.591544866079253,
|
| 1006 |
+
181.76986724267596,
|
| 1007 |
+
168.5741215400282,
|
| 1008 |
+
82.4353358815596,
|
| 1009 |
+
16.05609349144359,
|
| 1010 |
+
32.12324970648343
|
| 1011 |
+
],
|
| 1012 |
+
"q99": [
|
| 1013 |
+
48.55349563198916,
|
| 1014 |
+
186.10646680077767,
|
| 1015 |
+
173.6076722013997,
|
| 1016 |
+
93.41056417929472,
|
| 1017 |
+
43.53107398260694,
|
| 1018 |
+
44.74649336930881
|
| 1019 |
+
],
|
| 1020 |
+
"names": [
|
| 1021 |
+
"shoulder_pan",
|
| 1022 |
+
"shoulder_lift",
|
| 1023 |
+
"elbow_flex",
|
| 1024 |
+
"wrist_flex",
|
| 1025 |
+
"wrist_roll",
|
| 1026 |
+
"gripper"
|
| 1027 |
+
],
|
| 1028 |
+
"mask": [
|
| 1029 |
+
true,
|
| 1030 |
+
true,
|
| 1031 |
+
true,
|
| 1032 |
+
true,
|
| 1033 |
+
true,
|
| 1034 |
+
true
|
| 1035 |
+
]
|
| 1036 |
+
},
|
| 1037 |
+
"state_stats": {
|
| 1038 |
+
"min": [
|
| 1039 |
+
-115.048828125,
|
| 1040 |
+
-270.0,
|
| 1041 |
+
-235.8984375,
|
| 1042 |
+
-113.818359375,
|
| 1043 |
+
-268.9453125,
|
| 1044 |
+
-8.521058082580566
|
| 1045 |
+
],
|
| 1046 |
+
"max": [
|
| 1047 |
+
178.505859375,
|
| 1048 |
+
218.49609375,
|
| 1049 |
+
192.041015625,
|
| 1050 |
+
207.861328125,
|
| 1051 |
+
250.048828125,
|
| 1052 |
+
118.2519302368164
|
| 1053 |
+
],
|
| 1054 |
+
"mean": [
|
| 1055 |
+
3.3225097946752244,
|
| 1056 |
+
124.40594064960378,
|
| 1057 |
+
121.59550610749059,
|
| 1058 |
+
55.903039878016074,
|
| 1059 |
+
-11.41740021122887,
|
| 1060 |
+
13.358497334686597
|
| 1061 |
+
],
|
| 1062 |
+
"std": [
|
| 1063 |
+
28.79265204113751,
|
| 1064 |
+
52.702867303079756,
|
| 1065 |
+
47.00596021941705,
|
| 1066 |
+
35.53803566355756,
|
| 1067 |
+
69.12836626047817,
|
| 1068 |
+
16.333280282904557
|
| 1069 |
+
],
|
| 1070 |
+
"count": [
|
| 1071 |
+
19619650.0
|
| 1072 |
+
],
|
| 1073 |
+
"q01": [
|
| 1074 |
+
-41.90962240941357,
|
| 1075 |
+
43.66791235922949,
|
| 1076 |
+
38.38770483255723,
|
| 1077 |
+
5.711740446834044,
|
| 1078 |
+
-63.44539045209019,
|
| 1079 |
+
0.9435577790191543
|
| 1080 |
+
],
|
| 1081 |
+
"q10": [
|
| 1082 |
+
-24.949315993050774,
|
| 1083 |
+
66.30007546431412,
|
| 1084 |
+
68.16816985859437,
|
| 1085 |
+
27.120731646136054,
|
| 1086 |
+
-39.50255020332888,
|
| 1087 |
+
1.6190225837869365
|
| 1088 |
+
],
|
| 1089 |
+
"q50": [
|
| 1090 |
+
3.066375725640164,
|
| 1091 |
+
123.16482094240277,
|
| 1092 |
+
124.39930058290133,
|
| 1093 |
+
57.88605464633133,
|
| 1094 |
+
-11.037436711677765,
|
| 1095 |
+
9.241478261568748
|
| 1096 |
+
],
|
| 1097 |
+
"q90": [
|
| 1098 |
+
31.472920732960127,
|
| 1099 |
+
180.87158401301218,
|
| 1100 |
+
168.5699720215359,
|
| 1101 |
+
81.64709150074712,
|
| 1102 |
+
15.887605114617852,
|
| 1103 |
+
31.887861734718296
|
| 1104 |
+
],
|
| 1105 |
+
"q99": [
|
| 1106 |
+
48.29435703371732,
|
| 1107 |
+
185.2611055842669,
|
| 1108 |
+
173.13578487933165,
|
| 1109 |
+
91.78122415137209,
|
| 1110 |
+
42.94491979114059,
|
| 1111 |
+
44.13755601580974
|
| 1112 |
+
],
|
| 1113 |
+
"names": [
|
| 1114 |
+
"shoulder_pan",
|
| 1115 |
+
"shoulder_lift",
|
| 1116 |
+
"elbow_flex",
|
| 1117 |
+
"wrist_flex",
|
| 1118 |
+
"wrist_roll",
|
| 1119 |
+
"gripper"
|
| 1120 |
+
],
|
| 1121 |
+
"mask": [
|
| 1122 |
+
true,
|
| 1123 |
+
true,
|
| 1124 |
+
true,
|
| 1125 |
+
true,
|
| 1126 |
+
true,
|
| 1127 |
+
true
|
| 1128 |
+
]
|
| 1129 |
+
}
|
| 1130 |
+
},
|
| 1131 |
+
"google_robot_bc_z": {
|
| 1132 |
+
"action_key": "action",
|
| 1133 |
+
"state_key": "observation.state",
|
| 1134 |
+
"camera_keys": [
|
| 1135 |
+
"observation.images.image"
|
| 1136 |
+
],
|
| 1137 |
+
"normalize_gripper": false,
|
| 1138 |
+
"action_horizon": 10,
|
| 1139 |
+
"n_action_steps": 10,
|
| 1140 |
+
"setup_type": "google robot in bc_z",
|
| 1141 |
+
"control_mode": "delta end-effector pose",
|
| 1142 |
+
"action_stats": {
|
| 1143 |
+
"min": [
|
| 1144 |
+
-0.1677047461271286,
|
| 1145 |
+
-0.14630407094955444,
|
| 1146 |
+
-0.10066790133714676,
|
| 1147 |
+
-0.29421567916870117,
|
| 1148 |
+
-0.32101404666900635,
|
| 1149 |
+
-0.4635624885559082,
|
| 1150 |
+
0.0
|
| 1151 |
+
],
|
| 1152 |
+
"max": [
|
| 1153 |
+
0.2165454924106598,
|
| 1154 |
+
0.1251407265663147,
|
| 1155 |
+
0.09988310933113098,
|
| 1156 |
+
0.33544227480888367,
|
| 1157 |
+
0.28117990493774414,
|
| 1158 |
+
0.40614867210388184,
|
| 1159 |
+
1.0
|
| 1160 |
+
],
|
| 1161 |
+
"mean": [
|
| 1162 |
+
-0.009960200864471745,
|
| 1163 |
+
0.0009084977087131892,
|
| 1164 |
+
0.00499393515302369,
|
| 1165 |
+
0.00028739003438370427,
|
| 1166 |
+
-0.00871610909893306,
|
| 1167 |
+
-0.030692461306736755,
|
| 1168 |
+
0.8343520005664466
|
| 1169 |
+
],
|
| 1170 |
+
"std": [
|
| 1171 |
+
0.03080177058689462,
|
| 1172 |
+
0.023236620172139833,
|
| 1173 |
+
0.020777592916798007,
|
| 1174 |
+
0.041763587623031895,
|
| 1175 |
+
0.046686683400427,
|
| 1176 |
+
0.07753463216688747,
|
| 1177 |
+
0.3717643553432202
|
| 1178 |
+
],
|
| 1179 |
+
"count": [
|
| 1180 |
+
5471693.0
|
| 1181 |
+
],
|
| 1182 |
+
"q01": [
|
| 1183 |
+
-0.09213472068957661,
|
| 1184 |
+
-0.06450906318665113,
|
| 1185 |
+
-0.04912072456744037,
|
| 1186 |
+
-0.11609895664024446,
|
| 1187 |
+
-0.1413486404610977,
|
| 1188 |
+
-0.22517701597416145,
|
| 1189 |
+
-1.000000013351432e-10
|
| 1190 |
+
],
|
| 1191 |
+
"q10": [
|
| 1192 |
+
-0.05253115985050928,
|
| 1193 |
+
-0.028533985817234882,
|
| 1194 |
+
-0.021736428190829056,
|
| 1195 |
+
-0.04809403695382897,
|
| 1196 |
+
-0.0664864549799673,
|
| 1197 |
+
-0.1391167833364122,
|
| 1198 |
+
-1.000000013351432e-10
|
| 1199 |
+
],
|
| 1200 |
+
"q50": [
|
| 1201 |
+
-0.0031453596109414592,
|
| 1202 |
+
0.0004054125482836473,
|
| 1203 |
+
0.0023481391860319715,
|
| 1204 |
+
-8.489440239357886e-05,
|
| 1205 |
+
-0.002574837787014793,
|
| 1206 |
+
-0.014108526356650069,
|
| 1207 |
+
0.9998801266205536
|
| 1208 |
+
],
|
| 1209 |
+
"q90": [
|
| 1210 |
+
0.019494707527676427,
|
| 1211 |
+
0.029460992205482695,
|
| 1212 |
+
0.032557826189659966,
|
| 1213 |
+
0.04931595102291217,
|
| 1214 |
+
0.042994841552155126,
|
| 1215 |
+
0.05302803170853769,
|
| 1216 |
+
0.9999760253241107
|
| 1217 |
+
],
|
| 1218 |
+
"q99": [
|
| 1219 |
+
0.07630278211772451,
|
| 1220 |
+
0.05802308552485688,
|
| 1221 |
+
0.052553275338456634,
|
| 1222 |
+
0.1173714221625478,
|
| 1223 |
+
0.11711249897425843,
|
| 1224 |
+
0.1673988100025391,
|
| 1225 |
+
0.9999976025324111
|
| 1226 |
+
],
|
| 1227 |
+
"names": [
|
| 1228 |
+
"x",
|
| 1229 |
+
"y",
|
| 1230 |
+
"z",
|
| 1231 |
+
"roll",
|
| 1232 |
+
"pitch",
|
| 1233 |
+
"yaw",
|
| 1234 |
+
"gripper"
|
| 1235 |
+
],
|
| 1236 |
+
"mask": [
|
| 1237 |
+
true,
|
| 1238 |
+
true,
|
| 1239 |
+
true,
|
| 1240 |
+
true,
|
| 1241 |
+
true,
|
| 1242 |
+
true,
|
| 1243 |
+
false
|
| 1244 |
+
]
|
| 1245 |
+
},
|
| 1246 |
+
"state_stats": {
|
| 1247 |
+
"min": [
|
| 1248 |
+
-0.7190948724746704,
|
| 1249 |
+
-0.3756217360496521,
|
| 1250 |
+
-0.281008243560791,
|
| 1251 |
+
-2.400146484375,
|
| 1252 |
+
-2.500656843185425,
|
| 1253 |
+
-3.1274476051330566,
|
| 1254 |
+
0.0,
|
| 1255 |
+
0.0
|
| 1256 |
+
],
|
| 1257 |
+
"max": [
|
| 1258 |
+
0.6597589254379272,
|
| 1259 |
+
0.7259413599967957,
|
| 1260 |
+
1.1217665672302246,
|
| 1261 |
+
2.2803165912628174,
|
| 1262 |
+
1.8151572942733765,
|
| 1263 |
+
3.1237573623657227,
|
| 1264 |
+
0.0,
|
| 1265 |
+
1.0
|
| 1266 |
+
],
|
| 1267 |
+
"mean": [
|
| 1268 |
+
0.0176884768449917,
|
| 1269 |
+
0.10948195169606133,
|
| 1270 |
+
0.784290845584472,
|
| 1271 |
+
-0.5290053991424425,
|
| 1272 |
+
-0.22605912165135514,
|
| 1273 |
+
-0.17858785012278866,
|
| 1274 |
+
0.0,
|
| 1275 |
+
0.5600556496096702
|
| 1276 |
+
],
|
| 1277 |
+
"std": [
|
| 1278 |
+
0.1841601172406892,
|
| 1279 |
+
0.09627411033983578,
|
| 1280 |
+
0.08699189118288073,
|
| 1281 |
+
0.24700645691257475,
|
| 1282 |
+
0.4286554852012691,
|
| 1283 |
+
1.0001615516228195,
|
| 1284 |
+
0.0,
|
| 1285 |
+
0.3586031013748201
|
| 1286 |
+
],
|
| 1287 |
+
"count": [
|
| 1288 |
+
5471693.0
|
| 1289 |
+
],
|
| 1290 |
+
"q01": [
|
| 1291 |
+
-0.38789819221198557,
|
| 1292 |
+
-0.1118956928319213,
|
| 1293 |
+
0.6110697470322705,
|
| 1294 |
+
-1.0415028765133625,
|
| 1295 |
+
-1.1876200204022105,
|
| 1296 |
+
-2.3808376895782,
|
| 1297 |
+
0.0,
|
| 1298 |
+
0.19986777120588917
|
| 1299 |
+
],
|
| 1300 |
+
"q10": [
|
| 1301 |
+
-0.2318964688694949,
|
| 1302 |
+
-0.015558046064633315,
|
| 1303 |
+
0.6822309043992328,
|
| 1304 |
+
-0.7563316012340816,
|
| 1305 |
+
-0.7533119325741587,
|
| 1306 |
+
-1.3938289285869132,
|
| 1307 |
+
0.0,
|
| 1308 |
+
0.2000496453831541
|
| 1309 |
+
],
|
| 1310 |
+
"q50": [
|
| 1311 |
+
0.022859327635303822,
|
| 1312 |
+
0.10637610222856157,
|
| 1313 |
+
0.776611691927557,
|
| 1314 |
+
-0.5671171062825059,
|
| 1315 |
+
-0.24114911945667813,
|
| 1316 |
+
-0.25162686787881255,
|
| 1317 |
+
0.0,
|
| 1318 |
+
0.3501994619818789
|
| 1319 |
+
],
|
| 1320 |
+
"q90": [
|
| 1321 |
+
0.2666238156546802,
|
| 1322 |
+
0.23844897018337458,
|
| 1323 |
+
0.9059002565684082,
|
| 1324 |
+
-0.26983885858517637,
|
| 1325 |
+
0.3994129877275485,
|
| 1326 |
+
1.374448904817122,
|
| 1327 |
+
0.0,
|
| 1328 |
+
0.999900866490248
|
| 1329 |
+
],
|
| 1330 |
+
"q99": [
|
| 1331 |
+
0.3325375374171561,
|
| 1332 |
+
0.31715197447407467,
|
| 1333 |
+
0.982179447052214,
|
| 1334 |
+
0.34632693633800826,
|
| 1335 |
+
0.7713777675821983,
|
| 1336 |
+
2.029990628516839,
|
| 1337 |
+
0.0,
|
| 1338 |
+
0.9999900866490248
|
| 1339 |
+
],
|
| 1340 |
+
"names": [
|
| 1341 |
+
"x",
|
| 1342 |
+
"y",
|
| 1343 |
+
"z",
|
| 1344 |
+
"roll",
|
| 1345 |
+
"pitch",
|
| 1346 |
+
"yaw",
|
| 1347 |
+
"pad",
|
| 1348 |
+
"gripper"
|
| 1349 |
+
],
|
| 1350 |
+
"mask": [
|
| 1351 |
+
true,
|
| 1352 |
+
true,
|
| 1353 |
+
true,
|
| 1354 |
+
true,
|
| 1355 |
+
true,
|
| 1356 |
+
true,
|
| 1357 |
+
true,
|
| 1358 |
+
false
|
| 1359 |
+
]
|
| 1360 |
+
}
|
| 1361 |
+
},
|
| 1362 |
+
"yam_dual_molmoact2": {
|
| 1363 |
+
"action_key": "action",
|
| 1364 |
+
"state_key": "observation.state",
|
| 1365 |
+
"camera_keys": [
|
| 1366 |
+
"observation.images.top",
|
| 1367 |
+
"observation.images.left",
|
| 1368 |
+
"observation.images.right"
|
| 1369 |
+
],
|
| 1370 |
+
"normalize_gripper": false,
|
| 1371 |
+
"action_horizon": 30,
|
| 1372 |
+
"n_action_steps": 30,
|
| 1373 |
+
"setup_type": "bimanual yam robotic arms in molmoact2",
|
| 1374 |
+
"control_mode": "absolute joint pose",
|
| 1375 |
+
"action_stats": {
|
| 1376 |
+
"min": [
|
| 1377 |
+
-1.9876782894134521,
|
| 1378 |
+
-0.007057297509163618,
|
| 1379 |
+
-0.002861066721379757,
|
| 1380 |
+
-1.6958495378494263,
|
| 1381 |
+
-1.5730143785476685,
|
| 1382 |
+
-2.184138298034668,
|
| 1383 |
+
0.0,
|
| 1384 |
+
-1.6771572828292847,
|
| 1385 |
+
-0.00667582219466567,
|
| 1386 |
+
-0.0032425422687083483,
|
| 1387 |
+
-1.7061493396759033,
|
| 1388 |
+
-1.6287097930908203,
|
| 1389 |
+
-2.143320322036743,
|
| 1390 |
+
0.0
|
| 1391 |
+
],
|
| 1392 |
+
"max": [
|
| 1393 |
+
1.808003306388855,
|
| 1394 |
+
3.1988632678985596,
|
| 1395 |
+
3.1507973670959473,
|
| 1396 |
+
1.592851161956787,
|
| 1397 |
+
1.5890363454818726,
|
| 1398 |
+
2.2081711292266846,
|
| 1399 |
+
1.0,
|
| 1400 |
+
2.440871238708496,
|
| 1401 |
+
3.1084535121917725,
|
| 1402 |
+
3.1530861854553223,
|
| 1403 |
+
1.6649500131607056,
|
| 1404 |
+
1.5947585105895996,
|
| 1405 |
+
2.1639199256896973,
|
| 1406 |
+
1.0
|
| 1407 |
+
],
|
| 1408 |
+
"mean": [
|
| 1409 |
+
-0.08857854148141169,
|
| 1410 |
+
1.3813960226201991,
|
| 1411 |
+
1.2242081192216245,
|
| 1412 |
+
-0.7456114034786908,
|
| 1413 |
+
0.15342910390834139,
|
| 1414 |
+
-0.2406550926649683,
|
| 1415 |
+
0.6405881969404109,
|
| 1416 |
+
0.11816370494944337,
|
| 1417 |
+
1.3440412881232742,
|
| 1418 |
+
1.1275448419933234,
|
| 1419 |
+
-0.6567647967296087,
|
| 1420 |
+
-0.15745777770921981,
|
| 1421 |
+
0.20879381691599022,
|
| 1422 |
+
0.5971762495146153
|
| 1423 |
+
],
|
| 1424 |
+
"std": [
|
| 1425 |
+
0.31549225693975164,
|
| 1426 |
+
0.7241109409894698,
|
| 1427 |
+
0.6724976443740277,
|
| 1428 |
+
0.4912531895036823,
|
| 1429 |
+
0.3766601597067631,
|
| 1430 |
+
0.3683009171682207,
|
| 1431 |
+
0.41042883365599214,
|
| 1432 |
+
0.33538355728349317,
|
| 1433 |
+
0.8035033283123882,
|
| 1434 |
+
0.7129305114483252,
|
| 1435 |
+
0.5147389512393373,
|
| 1436 |
+
0.37362261558635523,
|
| 1437 |
+
0.35878804842243267,
|
| 1438 |
+
0.42346789755808983
|
| 1439 |
+
],
|
| 1440 |
+
"count": [
|
| 1441 |
+
76046658.0
|
| 1442 |
+
],
|
| 1443 |
+
"q01": [
|
| 1444 |
+
-0.6603105582072047,
|
| 1445 |
+
0.0041340051935240115,
|
| 1446 |
+
0.013831665477596221,
|
| 1447 |
+
-1.3744044717113109,
|
| 1448 |
+
-0.3593570239425977,
|
| 1449 |
+
-0.9302641712677729,
|
| 1450 |
+
0.051016362361406005,
|
| 1451 |
+
-0.49367228465810536,
|
| 1452 |
+
0.004744360313868616,
|
| 1453 |
+
0.017154297804418434,
|
| 1454 |
+
-1.4240273823045295,
|
| 1455 |
+
-0.9737084779331572,
|
| 1456 |
+
-0.4719268433374943,
|
| 1457 |
+
0.033350514024370274
|
| 1458 |
+
],
|
| 1459 |
+
"q10": [
|
| 1460 |
+
-0.4158939180171844,
|
| 1461 |
+
0.49040349295087926,
|
| 1462 |
+
0.48318427047331663,
|
| 1463 |
+
-1.1595704371830307,
|
| 1464 |
+
-0.13299944787425266,
|
| 1465 |
+
-0.5670792130135129,
|
| 1466 |
+
0.11117863560492024,
|
| 1467 |
+
-0.19067792775434206,
|
| 1468 |
+
0.19335683280594596,
|
| 1469 |
+
0.1783492294932824,
|
| 1470 |
+
-1.165289828212844,
|
| 1471 |
+
-0.5363078842413471,
|
| 1472 |
+
-0.11410713925580458,
|
| 1473 |
+
0.054251135868839034
|
| 1474 |
+
],
|
| 1475 |
+
"q50": [
|
| 1476 |
+
-0.07347940057883112,
|
| 1477 |
+
1.4486934996424023,
|
| 1478 |
+
1.2826819985862519,
|
| 1479 |
+
-0.8018464396181274,
|
| 1480 |
+
0.11333067563787286,
|
| 1481 |
+
-0.22188306769880142,
|
| 1482 |
+
0.7333514901431821,
|
| 1483 |
+
0.08159376899519756,
|
| 1484 |
+
1.542016049355695,
|
| 1485 |
+
1.2518141457542857,
|
| 1486 |
+
-0.6816567194944295,
|
| 1487 |
+
-0.12921257250905716,
|
| 1488 |
+
0.19217648232095094,
|
| 1489 |
+
0.6965966006454063
|
| 1490 |
+
],
|
| 1491 |
+
"q90": [
|
| 1492 |
+
0.21224325405051755,
|
| 1493 |
+
2.0044457220962184,
|
| 1494 |
+
1.7599272535504926,
|
| 1495 |
+
-0.17992348512991949,
|
| 1496 |
+
0.5121005560866031,
|
| 1497 |
+
0.06588770556098025,
|
| 1498 |
+
0.9798257827982823,
|
| 1499 |
+
0.49762827627115913,
|
| 1500 |
+
2.062871328579572,
|
| 1501 |
+
1.7914606668876476,
|
| 1502 |
+
-0.07308204053490945,
|
| 1503 |
+
0.182291301998786,
|
| 1504 |
+
0.5569780500008801,
|
| 1505 |
+
0.9922195168313757
|
| 1506 |
+
],
|
| 1507 |
+
"q99": [
|
| 1508 |
+
0.4704245731743921,
|
| 1509 |
+
2.244327078820327,
|
| 1510 |
+
2.0080105207169177,
|
| 1511 |
+
0.13399061379118773,
|
| 1512 |
+
0.8834156417282395,
|
| 1513 |
+
0.334483290041328,
|
| 1514 |
+
0.987078674113364,
|
| 1515 |
+
0.7377501348730936,
|
| 1516 |
+
2.285076596429336,
|
| 1517 |
+
2.0605540868103542,
|
| 1518 |
+
0.23968854170206916,
|
| 1519 |
+
0.5304791687465945,
|
| 1520 |
+
0.9621494841801348,
|
| 1521 |
+
0.9953596816858612
|
| 1522 |
+
],
|
| 1523 |
+
"names": [
|
| 1524 |
+
"left_joint_0.pos",
|
| 1525 |
+
"left_joint_1.pos",
|
| 1526 |
+
"left_joint_2.pos",
|
| 1527 |
+
"left_joint_3.pos",
|
| 1528 |
+
"left_joint_4.pos",
|
| 1529 |
+
"left_joint_5.pos",
|
| 1530 |
+
"left_gripper.pos",
|
| 1531 |
+
"right_joint_0.pos",
|
| 1532 |
+
"right_joint_1.pos",
|
| 1533 |
+
"right_joint_2.pos",
|
| 1534 |
+
"right_joint_3.pos",
|
| 1535 |
+
"right_joint_4.pos",
|
| 1536 |
+
"right_joint_5.pos",
|
| 1537 |
+
"right_gripper.pos"
|
| 1538 |
+
],
|
| 1539 |
+
"mask": [
|
| 1540 |
+
true,
|
| 1541 |
+
true,
|
| 1542 |
+
true,
|
| 1543 |
+
true,
|
| 1544 |
+
true,
|
| 1545 |
+
true,
|
| 1546 |
+
false,
|
| 1547 |
+
true,
|
| 1548 |
+
true,
|
| 1549 |
+
true,
|
| 1550 |
+
true,
|
| 1551 |
+
true,
|
| 1552 |
+
true,
|
| 1553 |
+
false
|
| 1554 |
+
]
|
| 1555 |
+
},
|
| 1556 |
+
"state_stats": {
|
| 1557 |
+
"min": [
|
| 1558 |
+
-1.971656322479248,
|
| 1559 |
+
0.00019073777366429567,
|
| 1560 |
+
0.001716639962978661,
|
| 1561 |
+
-1.7023346424102783,
|
| 1562 |
+
-1.576829195022583,
|
| 1563 |
+
-2.0963988304138184,
|
| 1564 |
+
0.0005250918911769986,
|
| 1565 |
+
-1.6741054058074951,
|
| 1566 |
+
-0.0009536888683214784,
|
| 1567 |
+
0.004386968910694122,
|
| 1568 |
+
-1.737811803817749,
|
| 1569 |
+
-1.574158787727356,
|
| 1570 |
+
-2.0941100120544434,
|
| 1571 |
+
0.003973988350480795
|
| 1572 |
+
],
|
| 1573 |
+
"max": [
|
| 1574 |
+
1.813725471496582,
|
| 1575 |
+
3.101205348968506,
|
| 1576 |
+
3.1466009616851807,
|
| 1577 |
+
1.5821698904037476,
|
| 1578 |
+
1.6222248077392578,
|
| 1579 |
+
2.1040284633636475,
|
| 1580 |
+
0.9997128844261169,
|
| 1581 |
+
2.4343862533569336,
|
| 1582 |
+
3.11112380027771,
|
| 1583 |
+
3.1492714881896973,
|
| 1584 |
+
1.5836957693099976,
|
| 1585 |
+
1.6062028408050537,
|
| 1586 |
+
2.1452276706695557,
|
| 1587 |
+
1.0
|
| 1588 |
+
],
|
| 1589 |
+
"mean": [
|
| 1590 |
+
-0.08969431138176573,
|
| 1591 |
+
1.3833397954729871,
|
| 1592 |
+
1.2214299123909826,
|
| 1593 |
+
-0.7438162535789633,
|
| 1594 |
+
0.15467924320885904,
|
| 1595 |
+
-0.2444551331990551,
|
| 1596 |
+
0.6477599794157677,
|
| 1597 |
+
0.11772745375836342,
|
| 1598 |
+
1.3475698442605246,
|
| 1599 |
+
1.1241839262647857,
|
| 1600 |
+
-0.657754523106273,
|
| 1601 |
+
-0.16024992695882134,
|
| 1602 |
+
0.2095172679704065,
|
| 1603 |
+
0.6019240399143698
|
| 1604 |
+
],
|
| 1605 |
+
"std": [
|
| 1606 |
+
0.3152726802877428,
|
| 1607 |
+
0.7215555774539155,
|
| 1608 |
+
0.6677525379386945,
|
| 1609 |
+
0.49249044506684236,
|
| 1610 |
+
0.3669531426180722,
|
| 1611 |
+
0.36500773276171394,
|
| 1612 |
+
0.4034043094483581,
|
| 1613 |
+
0.3350780291739786,
|
| 1614 |
+
0.8015514140140498,
|
| 1615 |
+
0.7087483761552382,
|
| 1616 |
+
0.5140769455948587,
|
| 1617 |
+
0.36485948060191936,
|
| 1618 |
+
0.35558886385685473,
|
| 1619 |
+
0.4187505380995499
|
| 1620 |
+
],
|
| 1621 |
+
"count": [
|
| 1622 |
+
76046658.0
|
| 1623 |
+
],
|
| 1624 |
+
"q01": [
|
| 1625 |
+
-0.6603467782218314,
|
| 1626 |
+
0.012553692652370085,
|
| 1627 |
+
0.021776265158983142,
|
| 1628 |
+
-1.3705572057237516,
|
| 1629 |
+
-0.3332034826366618,
|
| 1630 |
+
-0.9193192400336088,
|
| 1631 |
+
0.059239047676073166,
|
| 1632 |
+
-0.4935656974138795,
|
| 1633 |
+
0.012780929401173773,
|
| 1634 |
+
0.022236669213863816,
|
| 1635 |
+
-1.4227596196972356,
|
| 1636 |
+
-0.9434528867624581,
|
| 1637 |
+
-0.4598343195103144,
|
| 1638 |
+
0.037835498581155064
|
| 1639 |
+
],
|
| 1640 |
+
"q10": [
|
| 1641 |
+
-0.41642163282166217,
|
| 1642 |
+
0.49507907198249584,
|
| 1643 |
+
0.486584320872561,
|
| 1644 |
+
-1.1582997707602973,
|
| 1645 |
+
-0.12275828541607876,
|
| 1646 |
+
-0.5663963402767317,
|
| 1647 |
+
0.1261316463154828,
|
| 1648 |
+
-0.1908506486628405,
|
| 1649 |
+
0.1993559996076043,
|
| 1650 |
+
0.18204643795012038,
|
| 1651 |
+
-1.1656159852054215,
|
| 1652 |
+
-0.5295295866303873,
|
| 1653 |
+
-0.10955673634265617,
|
| 1654 |
+
0.06449180996120647
|
| 1655 |
+
],
|
| 1656 |
+
"q50": [
|
| 1657 |
+
-0.07460956719060403,
|
| 1658 |
+
1.4518741988602484,
|
| 1659 |
+
1.2790339607814287,
|
| 1660 |
+
-0.8004009901188069,
|
| 1661 |
+
0.11633919425925929,
|
| 1662 |
+
-0.2256239564587,
|
| 1663 |
+
0.7410515739838786,
|
| 1664 |
+
0.08125296212737787,
|
| 1665 |
+
1.546374492933441,
|
| 1666 |
+
1.2473645258976782,
|
| 1667 |
+
-0.6826830989658852,
|
| 1668 |
+
-0.13268823647237576,
|
| 1669 |
+
0.19324335769817771,
|
| 1670 |
+
0.6975293719700979
|
| 1671 |
+
],
|
| 1672 |
+
"q90": [
|
| 1673 |
+
0.21065792154366084,
|
| 1674 |
+
2.001783199519663,
|
| 1675 |
+
1.7536904322237028,
|
| 1676 |
+
-0.17570477043577734,
|
| 1677 |
+
0.5016373270395832,
|
| 1678 |
+
0.057081945863381375,
|
| 1679 |
+
0.9793483311612012,
|
| 1680 |
+
0.496661138954089,
|
| 1681 |
+
2.0633422575822404,
|
| 1682 |
+
1.784104252873167,
|
| 1683 |
+
-0.07449674242785952,
|
| 1684 |
+
0.17045548433242785,
|
| 1685 |
+
0.5532139533377123,
|
| 1686 |
+
0.9916430884848699
|
| 1687 |
+
],
|
| 1688 |
+
"q99": [
|
| 1689 |
+
0.4683004661020414,
|
| 1690 |
+
2.2309715341843326,
|
| 1691 |
+
1.9982285068319416,
|
| 1692 |
+
0.13319204881075056,
|
| 1693 |
+
0.8574646079271142,
|
| 1694 |
+
0.31881311685642116,
|
| 1695 |
+
0.9862640952345495,
|
| 1696 |
+
0.736253091937041,
|
| 1697 |
+
2.276675221510269,
|
| 1698 |
+
2.0496951704229227,
|
| 1699 |
+
0.23446313153252643,
|
| 1700 |
+
0.503194049828884,
|
| 1701 |
+
0.9489437100128476,
|
| 1702 |
+
0.9945109907992316
|
| 1703 |
+
],
|
| 1704 |
+
"names": [
|
| 1705 |
+
"left_joint_0.pos",
|
| 1706 |
+
"left_joint_1.pos",
|
| 1707 |
+
"left_joint_2.pos",
|
| 1708 |
+
"left_joint_3.pos",
|
| 1709 |
+
"left_joint_4.pos",
|
| 1710 |
+
"left_joint_5.pos",
|
| 1711 |
+
"left_gripper.pos",
|
| 1712 |
+
"right_joint_0.pos",
|
| 1713 |
+
"right_joint_1.pos",
|
| 1714 |
+
"right_joint_2.pos",
|
| 1715 |
+
"right_joint_3.pos",
|
| 1716 |
+
"right_joint_4.pos",
|
| 1717 |
+
"right_joint_5.pos",
|
| 1718 |
+
"right_gripper.pos"
|
| 1719 |
+
],
|
| 1720 |
+
"mask": [
|
| 1721 |
+
true,
|
| 1722 |
+
true,
|
| 1723 |
+
true,
|
| 1724 |
+
true,
|
| 1725 |
+
true,
|
| 1726 |
+
true,
|
| 1727 |
+
false,
|
| 1728 |
+
true,
|
| 1729 |
+
true,
|
| 1730 |
+
true,
|
| 1731 |
+
true,
|
| 1732 |
+
true,
|
| 1733 |
+
true,
|
| 1734 |
+
false
|
| 1735 |
+
]
|
| 1736 |
+
}
|
| 1737 |
+
}
|
| 1738 |
+
}
|
| 1739 |
+
}
|
processing_molmoact2.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processor class for MolmoAct2.
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
import dataclasses
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from transformers.image_utils import ImageInput
|
| 10 |
+
from transformers.video_utils import VideoInput
|
| 11 |
+
from transformers.processing_utils import (
|
| 12 |
+
Unpack,
|
| 13 |
+
ProcessingKwargs,
|
| 14 |
+
ProcessorMixin,
|
| 15 |
+
)
|
| 16 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 17 |
+
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
from transformers import AutoTokenizer
|
| 21 |
+
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
| 22 |
+
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = f"<im_patch>"  # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = f"<im_low>"  # Where to insert low-res tokens
IM_START_TOKEN = f"<im_start>"  # Marks the start of a (high-res) image token run
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"  # Marks the start of a low-res (single-crop) run
FRAME_START_TOKEN = f"<frame_start>"  # Marks the start of one video frame's tokens
IM_END_TOKEN = f"<im_end>"  # Marks the end of an image token run
FRAME_END_TOKEN= f"<frame_end>"  # Marks the end of one video frame's tokens
IM_COL_TOKEN = f"<im_col>"  # Row separator emitted after each row of patch tokens
IMAGE_PROMPT = "<|image|>"  # User-facing placeholder replaced with image tokens by the processor
VIDEO_PROMPT = "<|video|>"  # User-facing placeholder replaced with video tokens by the processor

# All special tokens that count as "multimodal" — used to build `image_token_ids`
# for the mm token-type mask in MolmoAct2Processor.__call__.
IMAGE_TOKENS = [
    IMAGE_PATCH_TOKEN,
    IM_COL_TOKEN,
    IM_START_TOKEN,
    LOW_RES_IMAGE_START_TOKEN,
    FRAME_START_TOKEN,
    IM_END_TOKEN,
    FRAME_END_TOKEN,
    IMAGE_LOW_RES_TOKEN,
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
    """MolmoAct2 processor kwargs.

    Declares the typed kwarg groups accepted by `MolmoAct2Processor.__call__`
    and their defaults (merged via `ProcessorMixin._merge_kwargs`).
    """
    images_kwargs: MolmoAct2ImagesKwargs
    videos_kwargs: MolmoAct2VideoProcessorKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
            # Ask for the multimodal token-type mask by default; the processor
            # pops this flag before calling the tokenizer.
            "return_mm_token_type_ids": True,
        },
        # Metadata (frame timestamps) is always requested internally because
        # the processor needs it to build per-frame time prefixes.
        "videos_kwargs": {"return_metadata": True},
    }
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MolmoAct2Processor(ProcessorMixin):
    """Joint processor for MolmoAct2.

    Wraps an image processor, a video processor, and a tokenizer. Replaces the
    `<|image|>` / `<|video|>` placeholders in the input text with runs of
    special image tokens sized according to the image/video grids returned by
    the sub-processors, tokenizes the result, and ensures a BOS token is
    present at the start of each sequence.
    """

    attributes = ["image_processor", "video_processor", "tokenizer"]
    optional_attributes = [
        "chat_template",
        "time_mode",
        "image_use_col_tokens",
        "use_single_crop_col_tokens",
        "use_single_crop_start_token",
        "video_use_col_tokens",
        "use_frame_special_tokens",
    ]
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor: MolmoAct2ImageProcessor = None,
        video_processor: MolmoAct2VideoProcessor = None,
        tokenizer: AutoTokenizer = None,
        chat_template: Optional[str] = None,
        image_use_col_tokens: Optional[bool] = True,
        use_single_crop_col_tokens: Optional[bool] = None,
        use_single_crop_start_token: Optional[bool] = True,
        video_use_col_tokens: Optional[bool] = False,
        use_frame_special_tokens: Optional[bool] = True,
        **kwargs
    ) -> None:
        """Build the processor.

        Args:
            image_processor: crops/resizes images and reports their token grids.
            video_processor: samples frames and reports video token grids.
            tokenizer: text tokenizer; must contain the special tokens in
                ``IMAGE_TOKENS`` (their ids are looked up here).
            chat_template: optional chat template forwarded to ProcessorMixin.
            image_use_col_tokens: emit a ``<im_col>`` token after each row of
                high-res image patch tokens.
            use_single_crop_col_tokens: same, for the low-res/single-crop run;
                ``None`` means "follow ``image_use_col_tokens``".
            use_single_crop_start_token: start the low-res run with
                ``<low_res_im_start>`` instead of ``<im_start>``.
            video_use_col_tokens: emit column tokens inside video frame runs.
            use_frame_special_tokens: delimit frames with
                ``<frame_start>``/``<frame_end>`` instead of the image
                start/end tokens.
        """
        super().__init__(
            image_processor,
            video_processor,
            tokenizer,
            chat_template=chat_template,
        )
        self.image_use_col_tokens = image_use_col_tokens
        self.use_single_crop_col_tokens = use_single_crop_col_tokens
        self.use_single_crop_start_token = use_single_crop_start_token
        self.video_use_col_tokens = video_use_col_tokens
        self.use_frame_special_tokens = use_frame_special_tokens

        self.image_placeholder_token = IMAGE_PROMPT
        self.video_placeholder_token = VIDEO_PROMPT
        # Token ids treated as "multimodal" when building token_type_ids.
        self.image_token_ids = [
            tokenizer.convert_tokens_to_ids(token)
            for token in IMAGE_TOKENS
        ]

    def get_image_tokens(self, image_grid: np.ndarray):
        """Return the 1-D array of special-token strings for one image.

        ``image_grid`` unpacks as (resized_h, resized_w, height, width):
        the low-res (single-crop) grid followed by the high-res grid.
        When the high-res grid is empty (height or width == 0) only the
        low-res run is emitted; otherwise the low-res run is prepended to
        the high-res run.
        """
        resized_h, resized_w, height, width = image_grid
        if int(height) == 0 or int(width) == 0:
            # No high-res crops: emit only the low-res run, using the plain
            # <im_start> token (not the low-res start token) in this branch.
            per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
            use_single_crop_col_tokens = (
                self.image_use_col_tokens
                if self.use_single_crop_col_tokens is None
                else self.use_single_crop_col_tokens
            )
            if use_single_crop_col_tokens:
                per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
            joint = [
                [IM_START_TOKEN],
                np.tile(per_row, [resized_h]),
                [IM_END_TOKEN],
            ]
            return np.concatenate(joint)
        # High-res run: `width` patch tokens per row, `height` rows.
        per_row = np.full(width, IMAGE_PATCH_TOKEN)
        if self.image_use_col_tokens:
            per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
        joint = [
            [IM_START_TOKEN],
            np.tile(per_row, [height]),
            [IM_END_TOKEN],
        ]
        # Low-res run, prepended before the high-res run below.
        per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
        use_single_crop_col_tokens = (
            self.image_use_col_tokens
            if self.use_single_crop_col_tokens is None
            else self.use_single_crop_col_tokens
        )
        image_start_token = (
            LOW_RES_IMAGE_START_TOKEN
            if self.use_single_crop_start_token
            else IM_START_TOKEN
        )
        if use_single_crop_col_tokens:
            per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
        joint = [
            [image_start_token],
            np.tile(per_row, [resized_h]),
            [IM_END_TOKEN],
        ] + joint

        return np.concatenate(joint)

    def get_video_string(
        self,
        video_grid: np.ndarray,
        timestamps: np.ndarray,
    ):
        """Build the text string for one video.

        Each frame contributes ``"<t> "`` (its timestamp, 1 decimal, with a
        separating space between frames) followed by a start token, ``h*w``
        patch tokens (optionally with column tokens), and an end token.
        ``video_grid`` unpacks as (num_frames, h, w); ``num_frames`` is not
        used directly — frames are driven by ``timestamps``.
        """
        if self.use_frame_special_tokens:
            start_token_id = FRAME_START_TOKEN
            end_token_id = FRAME_END_TOKEN
        else:
            start_token_id = IM_START_TOKEN
            end_token_id = IM_END_TOKEN

        num_frames, h, w = video_grid
        video_string: str = ""
        for frame_idx, frame_time in enumerate(timestamps):
            # `per-frame-compact` time mode
            prev_space = " " if frame_idx > 0 else ""
            frame_prefix = prev_space + f"{frame_time:.1f} "  # explicit whitespace before/after image tokens

            video_string += frame_prefix
            per_row = np.full(w, IMAGE_PATCH_TOKEN)
            if self.video_use_col_tokens:
                per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
            extra_tokens = np.tile(per_row, [h])
            video_tokens = [
                [start_token_id],
                extra_tokens,
                [end_token_id],
            ]
            video_string += "".join(np.concatenate(video_tokens, 0))

        return video_string

    def insert_bos(
        self,
        input_ids: np.ndarray,
        attention_mask: np.ndarray,
        bos_token_id: int,
        pad_token_id: int,
    ):
        """Insert a BOS token at the first valid position of each row, unless
        it is already present.

        Args:
            input_ids: [B, S] array with left padding
            attention_mask: [B, S] array (0 for pad, 1 for valid)
            bos_token_id: int
            pad_token_id: int
        Returns:
            input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
            attention_mask_out: same shape as input_ids_out

        Note: BOS presence is decided with `np.all` over the batch — insertion
        is all-or-nothing, assuming uniform batches.
        """

        # Accept 1-D input by temporarily promoting to a batch of one.
        need_to_expand = len(input_ids.shape) == 1
        if need_to_expand:
            input_ids = input_ids[None, :]
            attention_mask = attention_mask[None, :]

        B, S = input_ids.shape

        # Handle zero-length sequence
        if S == 0:
            new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
            new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
            if need_to_expand:
                new_input_ids = new_input_ids[0]
                new_attention_mask = new_attention_mask[0]
            return new_input_ids, new_attention_mask

        # With left padding, argmax over (mask == 1) finds the first real token.
        first_valid_index = (attention_mask == 1).argmax(axis=-1)  # [B]
        bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)

        if bos_already_present:
            if need_to_expand:
                input_ids = input_ids[0]
                attention_mask = attention_mask[0]
            return input_ids, attention_mask
        else:
            # Grow by one column and scatter the valid tokens one slot to the
            # right, then write BOS at each row's original first-valid slot.
            new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
            new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)

            src_idx = np.tile(np.arange(S), (B, 1))  # [B, S]
            valid_mask = src_idx >= first_valid_index[:, None]  # [B, S]
            tgt_idx = src_idx + 1  # shift right
            batch_idx = np.tile(np.arange(B)[:, None], (1, S))  # [B, S]

            # flatten valid_positions
            flat_vals = input_ids[valid_mask]
            flat_batch = batch_idx[valid_mask]
            flat_tgt = tgt_idx[valid_mask]

            new_input_ids[flat_batch, flat_tgt] = flat_vals
            new_attention_mask[flat_batch, flat_tgt] = 1

            insert_pos = first_valid_index
            new_input_ids[np.arange(B), insert_pos] = bos_token_id
            new_attention_mask[np.arange(B), insert_pos] = 1

            if need_to_expand:
                new_input_ids = new_input_ids[0]
                new_attention_mask = new_attention_mask[0]

            return new_input_ids, new_attention_mask

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        images: ImageInput = None,
        videos: VideoInput = None,
        **kwargs: Unpack[MolmoAct2ProcessorKwargs],
    ) -> BatchFeature:
        """Prepare text, images, and videos for the model.

        Args:
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            videos (`dict[str, Any]` or `list[dict[str, Any]]`):
                The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
                - `"frames"`: `np.ndarray` of shape (T, H, W, 3)
                - `"timestamps"`: `np.ndarray` of shape (T,)
                - `"sampled_fps"`: `float` (optional)
                - `"sampling_augmentation"`: `str` (optional)
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            `BatchFeature`: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
              Returned when `images` is not `None`.
            - **image_grids** -- Grids of images. Returned when `images` is not `None`.
            - **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
              Returned when `videos` is not `None`.
            - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
        """

        output_kwargs = self._merge_kwargs(
            MolmoAct2ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
            image_grids = image_inputs["image_grids"]
        else:
            image_inputs = {}
            image_grids = None

        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grids = videos_inputs["video_grids"]
            # If user has not requested video metadata, pop it
            # (metadata is always produced internally — see _defaults — because
            # the timestamps are needed to build frame prefixes below).
            if "return_metadata" not in kwargs:
                video_metadata = videos_inputs.pop("video_metadata")
            else:
                video_metadata = videos_inputs["video_metadata"]
        else:
            videos_inputs = {}
            video_grids = None

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place

        # Expand each <|image|> placeholder into the image's token string,
        # consuming image grids in order across the batch.
        if image_grids is not None:
            index = 0
            for i in range(len(text)):
                num_images = text[i].count(self.image_placeholder_token)
                image_grids_i = image_grids[index:index+num_images]
                for image_grid in image_grids_i:
                    image_tokens = self.get_image_tokens(image_grid)
                    image_string = "".join(image_tokens)
                    text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
                index += num_images

        # Expand each <|video|> placeholder (at most one per sequence).
        if video_grids is not None:
            index = 0
            for i in range(len(text)):
                num_videos = text[i].count(self.video_placeholder_token)
                assert num_videos in {0, 1}, "At most one video is supported for now"
                video_grids_i = video_grids[index:index+num_videos]
                metadata_i = video_metadata[index:index+num_videos]
                for video_grid, metadata in zip(video_grids_i, metadata_i):
                    video_string = self.get_video_string(
                        video_grid,
                        metadata.timestamps,
                    )
                    text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
                index += num_videos

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        input_ids = np.array(input_ids)
        attention_mask = np.array(attention_mask)

        # Fall back to EOS as BOS when the tokenizer defines no BOS token.
        bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
        input_ids, attention_mask = self.insert_bos(
            input_ids, attention_mask, bos, self.tokenizer.pad_token_id
        )

        if return_mm_token_type_ids:
            # Mark positions holding any multimodal special token with 1.
            image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
            token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
            text_inputs["token_type_ids"] = token_type_ids.tolist()

        text_inputs["input_ids"] = input_ids.tolist()
        text_inputs["attention_mask"] = attention_mask.tolist()

        return BatchFeature(
            data={**text_inputs, **image_inputs, **videos_inputs},
            tensor_type=return_tensors,
        )

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode method`.

        Returns:
            `list[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# Register the processor so `AutoProcessor.from_pretrained(..., trust_remote_code=True)` finds it.
MolmoAct2Processor.register_for_auto_class()
|
processor_config.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
|
| 4 |
+
},
|
| 5 |
+
"image_processor": {
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
|
| 8 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
|
| 9 |
+
},
|
| 10 |
+
"crop_mode": "resize",
|
| 11 |
+
"do_convert_rgb": true,
|
| 12 |
+
"image_mean": [
|
| 13 |
+
0.5,
|
| 14 |
+
0.5,
|
| 15 |
+
0.5
|
| 16 |
+
],
|
| 17 |
+
"image_processor_type": "MolmoAct2ImageProcessor",
|
| 18 |
+
"image_std": [
|
| 19 |
+
0.5,
|
| 20 |
+
0.5,
|
| 21 |
+
0.5
|
| 22 |
+
],
|
| 23 |
+
"max_crops": 8,
|
| 24 |
+
"overlap_margins": [
|
| 25 |
+
4,
|
| 26 |
+
4
|
| 27 |
+
],
|
| 28 |
+
"patch_size": 14,
|
| 29 |
+
"pooling_size": [
|
| 30 |
+
2,
|
| 31 |
+
2
|
| 32 |
+
],
|
| 33 |
+
"resample": 2,
|
| 34 |
+
"size": {
|
| 35 |
+
"height": 378,
|
| 36 |
+
"width": 378
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"image_use_col_tokens": true,
|
| 40 |
+
"processor_class": "MolmoAct2Processor",
|
| 41 |
+
"use_frame_special_tokens": true,
|
| 42 |
+
"use_single_crop_col_tokens": false,
|
| 43 |
+
"use_single_crop_start_token": true,
|
| 44 |
+
"video_processor": {
|
| 45 |
+
"auto_map": {
|
| 46 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
|
| 47 |
+
"AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
|
| 48 |
+
},
|
| 49 |
+
"data_format": "channels_first",
|
| 50 |
+
"default_to_square": true,
|
| 51 |
+
"do_convert_rgb": true,
|
| 52 |
+
"do_normalize": true,
|
| 53 |
+
"do_rescale": true,
|
| 54 |
+
"do_resize": true,
|
| 55 |
+
"do_sample_frames": true,
|
| 56 |
+
"frame_sample_mode": "uniform_last_frame",
|
| 57 |
+
"image_mean": [
|
| 58 |
+
0.5,
|
| 59 |
+
0.5,
|
| 60 |
+
0.5
|
| 61 |
+
],
|
| 62 |
+
"image_std": [
|
| 63 |
+
0.5,
|
| 64 |
+
0.5,
|
| 65 |
+
0.5
|
| 66 |
+
],
|
| 67 |
+
"max_fps": 2.0,
|
| 68 |
+
"num_frames": 8,
|
| 69 |
+
"patch_size": 14,
|
| 70 |
+
"pooling_size": [
|
| 71 |
+
3,
|
| 72 |
+
3
|
| 73 |
+
],
|
| 74 |
+
"resample": 2,
|
| 75 |
+
"rescale_factor": 0.00392156862745098,
|
| 76 |
+
"return_metadata": false,
|
| 77 |
+
"sampling_fps": 2,
|
| 78 |
+
"size": {
|
| 79 |
+
"height": 378,
|
| 80 |
+
"width": 378
|
| 81 |
+
},
|
| 82 |
+
"video_processor_type": "MolmoAct2VideoProcessor"
|
| 83 |
+
},
|
| 84 |
+
"video_use_col_tokens": false
|
| 85 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b6aeec78de2b0c7e95d7ae9d71cd04eba3d57351045a86c95520730e9c80d83
|
| 3 |
+
size 12176547
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
|
| 5 |
+
},
|
| 6 |
+
"backend": "tokenizers",
|
| 7 |
+
"bos_token": "<|im_end|>",
|
| 8 |
+
"clean_up_tokenization_spaces": false,
|
| 9 |
+
"eos_token": "<|im_end|>",
|
| 10 |
+
"errors": "replace",
|
| 11 |
+
"extra_special_tokens": [
|
| 12 |
+
"<im_start>",
|
| 13 |
+
"<im_end>",
|
| 14 |
+
"<im_patch>",
|
| 15 |
+
"<im_col>",
|
| 16 |
+
"<low_res_im_start>",
|
| 17 |
+
"<|image|>",
|
| 18 |
+
"<im_low>",
|
| 19 |
+
"<frame_start>",
|
| 20 |
+
"<frame_end>",
|
| 21 |
+
"<|video|>",
|
| 22 |
+
"<|points|>",
|
| 23 |
+
"<|token_index|>",
|
| 24 |
+
"<|vit_index|>",
|
| 25 |
+
"<|vit_loc|>"
|
| 26 |
+
],
|
| 27 |
+
"is_local": false,
|
| 28 |
+
"model_max_length": 1010000,
|
| 29 |
+
"pad_token": "<|endoftext|>",
|
| 30 |
+
"processor_class": "MolmoAct2Processor",
|
| 31 |
+
"split_special_tokens": false,
|
| 32 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 33 |
+
"unk_token": null
|
| 34 |
+
}
|
video_processing_molmoact2.py
ADDED
|
@@ -0,0 +1,969 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Video processor class for MolmoAct2"""
|
| 2 |
+
from functools import partial
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
from contextlib import redirect_stdout
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from urllib.parse import urlparse
|
| 8 |
+
from typing import Optional, Union, Callable
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import einops
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms
|
| 15 |
+
|
| 16 |
+
from transformers.image_utils import (
|
| 17 |
+
IMAGENET_STANDARD_MEAN,
|
| 18 |
+
IMAGENET_STANDARD_STD,
|
| 19 |
+
ImageInput,
|
| 20 |
+
PILImageResampling,
|
| 21 |
+
SizeDict,
|
| 22 |
+
validate_kwargs,
|
| 23 |
+
)
|
| 24 |
+
from transformers.video_utils import (
|
| 25 |
+
VideoInput,
|
| 26 |
+
is_valid_video,
|
| 27 |
+
make_batched_videos,
|
| 28 |
+
make_batched_metadata,
|
| 29 |
+
VideoMetadata,
|
| 30 |
+
)
|
| 31 |
+
from transformers.processing_utils import Unpack, VideosKwargs
|
| 32 |
+
from transformers.video_processing_utils import BaseVideoProcessor
|
| 33 |
+
from transformers.utils import logging
|
| 34 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 35 |
+
from transformers.utils import (
|
| 36 |
+
is_av_available,
|
| 37 |
+
is_decord_available,
|
| 38 |
+
is_torchcodec_available,
|
| 39 |
+
is_yt_dlp_available,
|
| 40 |
+
TensorType,
|
| 41 |
+
logging,
|
| 42 |
+
to_numpy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
MAX_VIDEO_FPS = 8
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normalize_image(
    image: np.ndarray,
    image_mean: list[float],
    image_std: list[float],
) -> np.ndarray:
    """Normalize a channel-last float image (HWC or THWC) with per-channel mean/std.

    Fix: the original implementation mutated `image` in place via `-=` / `/=`
    (and therefore also required a float input); this version computes the
    result out-of-place and leaves the caller's array untouched. Return values
    are unchanged.

    Args:
        image: float array in [0, 1], channels last.
        image_mean: per-channel mean (length 3).
        image_std: per-channel std (length 3).

    Returns:
        Normalized array of the same shape.
    """
    # Fast path for the common [-1, 1] normalization (mean = std = 0.5).
    if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
        return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
    # (1, 1, 3) broadcasts against both HWC and THWC inputs (trailing dims align).
    mean = np.array(image_mean, dtype=np.float32)[None, None, :]
    std = np.array(image_std, dtype=np.float32)[None, None, :]
    return (image - mean) / std
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def resize_image(
    image: np.ndarray,
    desired_output_size: list[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an image (HWC) or video clip (THWC) and rescale pixel values to [0, 1].

    Float inputs are assumed to already be in [0, 1]; uint8 inputs are in
    [0, 255]. Either way the output is float32 in [0, 1], channel-last, with
    the same rank as the input.
    """
    if len(image.shape) == 3:
        # Single image: HWC -> CHW, the layout torchvision's Resize expects.
        is_video = False
        image = torch.permute(torch.from_numpy(image), [2, 0, 1])
    else:
        # Video clip: THWC -> TCHW (leading time dim is treated as a batch).
        is_video = True
        image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
    dtype = image.dtype
    if torch.is_floating_point(image):
        in_min = 0.0
        in_max = 1.0
        resized = torchvision.transforms.Resize(
            desired_output_size,
            resample,
            antialias=False,
        )(image)
        # Interpolation can overshoot the valid range; clamp before casting back.
        resized = torch.clip(resized, 0.0, 1.0).to(dtype)
    else:
        assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
        in_min = 0.0
        in_max = 255.0
        resized = torchvision.transforms.Resize(
            desired_output_size,
            resample,
            antialias=False,
        )(image)
        resized = torch.clip(resized, 0, 255).to(dtype)

    # Rescale to [0, 1] float32 regardless of the original dtype.
    resized = resized.to(torch.float32)
    resized = (resized - in_min) / (in_max - in_min)

    # Restore channel-last layout for downstream numpy processing.
    if is_video:
        resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
    else:
        resized = torch.permute(resized, [1, 2, 0]).numpy()

    return resized
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def build_resized_image(
    image: np.ndarray,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Resize + normalize `image` and build the per-patch index grid for the crop.

    Returns:
        resized: normalized crop(s), always with a leading crop/frame dimension.
        resize_idx: [h_patches, w_patches] int array mapping each patch
            position to its flat patch index within the crop.
    """
    resized = resize_image(
        image, base_image_input_size, resample,
    )
    resized = normalize_image(resized, image_mean, image_std)
    if len(resized.shape) == 3:
        # Single image: add a leading crop dimension so images and videos can
        # be handled uniformly downstream.
        resized = np.expand_dims(resized, 0)
    # Number of ViT patches along each axis of the (fixed-size) crop.
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size
    resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w])
    return resized, resize_idx
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
    p = patch_size
    if array.ndim == 3:
        # Grayscale / single-channel layout: [n, h, w].
        n_imgs, height, width = array.shape
        grid_h, grid_w = height // p, width // p
        patched = array.reshape(n_imgs, grid_h, p, grid_w, p)
        patched = patched.swapaxes(2, 3)
        return patched.reshape(n_imgs, grid_h * grid_w, p * p)
    # Channel-last layout: [n, h, w, c].
    n_imgs, height, width, channels = array.shape
    grid_h, grid_w = height // p, width // p
    patched = array.reshape(n_imgs, grid_h, p, grid_w, p, channels)
    patched = patched.transpose(0, 1, 3, 2, 4, 5)
    return patched.reshape(n_imgs, grid_h * grid_w, p * p * channels)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a 2-D patch-index grid into pooling windows.

    Pads `idx_arr` (with -1) so both dimensions are multiples of the pooling
    window, splitting the padding evenly between the two sides, then gathers
    each `pool_h x pool_w` window into the last axis.

    Improvement: the `einops.rearrange("(h dh) (w dw) -> h w (dh dw)")` call
    is replaced with the equivalent pure-NumPy reshape/transpose, dropping the
    third-party dependency for this helper. Output is identical.

    Returns:
        int array of shape [h, w, pool_h * pool_w]; -1 marks padded positions.
    """
    h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
    w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
    idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]],
                     mode='constant', constant_values=-1)
    out_h = idx_arr.shape[0] // pool_h
    out_w = idx_arr.shape[1] // pool_w
    # (h*dh, w*dw) -> (h, dh, w, dw) -> (h, w, dh, dw) -> (h, w, dh*dw)
    pooled = idx_arr.reshape(out_h, pool_h, out_w, pool_w)
    pooled = pooled.transpose(0, 2, 1, 3)
    return pooled.reshape(out_h, out_w, pool_h * pool_w)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def image_to_patches_and_grids(
    image: ImageInput,
    base_image_input_size: list[int],
    resample: PILImageResampling,
    image_mean: list[float],
    image_std: list[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert one image (or clip) into ViT patches plus pooling bookkeeping.

    :return image_grids, the shape of each image after pooling
    :return crops, the image crops to processes with the ViT
    :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
        patches in `crops` to pool for that token, masked with -1
    """
    if isinstance(base_image_input_size, int):
        # Allow a square crop size to be given as a single int.
        base_image_input_size = (base_image_input_size, base_image_input_size)

    pooling_w = image_pooling_w
    pooling_h = image_pooling_h

    resized, resize_idx = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    # Group patch indices into pooling windows; -1 marks padded positions.
    pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
    # (h, w) here is the token grid *after* pooling.
    h, w = pooling_idx.shape[:2]
    pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w])
    image_grid = [h, w]
    return (
        image_grid,
        batch_pixels_to_patches(resized, image_patch_size),
        pooling_idx,
    )
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_candidate_target_fps(
    video_fps: Union[int, float],
    sampling_fps: Union[int, float],
    max_fps: Union[int, float] = MAX_VIDEO_FPS,
) -> list[float]:
    """
    Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.

    Examples:
        >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
        [2.0, 6.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
        [1.0, 5.0]
        >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
        [2.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
        Traceback (most recent call last):
        ...
        ValueError: sampling_fps=2 must divide video_fps=5.
    """
    # Bug fix: the None check must run *before* int() conversion; previously
    # `int(None)` raised TypeError instead of the intended ValueError.
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")

    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)

    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
    if video_fps % sampling_fps != 0:
        raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")

    # Multiples of sampling_fps that also evenly divide video_fps, capped at max_fps.
    candidates = []
    for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
        if candidate > max_fps:
            break
        if video_fps % candidate == 0:
            candidates.append(float(candidate))

    return candidates
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def read_video_decord(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> np.ndarray:
    """
    Decode a video using the Decord backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import from decord
    import importlib
    decord = importlib.import_module("decord")

    vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0))  # decord has problems with gpu
    video_fps = vr.get_avg_fps()
    total_num_frames = len(vr)
    # Per-frame (start, end) presentation times for every frame in the video.
    time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
    duration = time_stamps[-1][1] - time_stamps[0][0]

    metadata = VideoMetadata(
        total_num_frames=int(total_num_frames),
        fps=float(video_fps),
        duration=float(duration),
        video_backend="decord",
    )

    target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
    target_timestamps = np.array(target_timestamps)
    # Shift requested times into the stream's own timeline (first frame may not start at 0).
    offset = time_stamps[0, 0]

    # For each requested time, pick the frame whose end-time first exceeds it.
    ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side='right')
    ix = np.minimum(ix, len(time_stamps) - 1)

    video = vr.get_batch(ix).asnumpy()
    # NOTE(review): assumes `VideoMetadata` exposes a dict-like `update` —
    # confirm against the installed transformers version.
    metadata.update(
        {
            "frames_indices": target_timestamps * video_fps,
            "height": video.shape[1],
            "width": video.shape[2],
        }
    )
    return video, metadata
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def read_video_torchcodec(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> np.ndarray:
    """
    Decode a video using torchcodec decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import torchcodec
    import importlib
    torchcodec = importlib.import_module("torchcodec")

    decoder = torchcodec.decoders.VideoDecoder(
        video_path,
        # Interestingly `exact` mode takes less than approximate when we load the whole video
        seek_mode="exact",
        # Allow FFmpeg decide on the number of threads for efficiency
        num_ffmpeg_threads=0,
    )
    # If the first frame starts at > 0, we effectively clip the video starting at that time
    # since (most) video players would also skip to that time
    time_offset = decoder.metadata.begin_stream_seconds_from_content
    # Note this duration does assume we started playing at `time_offset`
    duration = decoder.metadata.duration_seconds

    metadata = VideoMetadata(
        total_num_frames=decoder.metadata.num_frames,
        fps=decoder.metadata.average_fps,
        duration=duration,
        video_backend="torchcodec",
        height=decoder.metadata.height,
        width=decoder.metadata.width,
    )

    target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)

    # Floating point/rounding issues might cause `target_timestamps` to be very slightly
    # out-of-bounds, to handle this we sanity check then clip them
    assert all(x >= 0 for x in target_timestamps)
    assert all(x < duration+1e-6 for x in target_timestamps)
    # 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
    # exact boundary value, we should still get the first/last frame anyway
    max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
    min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
    # Note we avoid using numpy ops here to reduce floating precision issues
    timestamps = [x + time_offset for x in target_timestamps]
    timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]

    video = decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)  # Convert to THWC format
    target_timestamps = np.array(target_timestamps)
    # Record the (possibly fractional) frame indices the sampled times map to.
    metadata.frames_indices = target_timestamps * metadata.fps

    return video, metadata
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def read_video_pyav(
    video_path,
    sample_timestamps_fn: Callable,
    **kwargs,
) -> np.ndarray:
    """
    Decode a video using the PyAV backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import av (comment said torchcodec; the module loaded is PyAV).
    import importlib
    av = importlib.import_module("av")

    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        fps = video_stream.average_rate or video_stream.guessed_rate
        # Decode the entire stream eagerly; sampling happens afterwards.
        it = container.decode(video=0)
        frames = list(it)

        stream = container.streams.video[0]
        # Stream times are in units of `stream.time_base` seconds.
        start = frames[0].pts * stream.time_base
        container_end = stream.duration
        if container_end is not None:
            container_end *= stream.time_base
        if container_end is None or container_end < frames[-1].pts:
            # Some problem with stream duration, so use the frame PTS directly
            # and guess the duration of the last frame
            end = frames[-1].pts * stream.time_base + 1/fps
        else:
            end = container_end
        duration = float(end - start)

        metadata = VideoMetadata(
            total_num_frames=len(frames),
            fps=float(fps),
            duration=float(duration),
            video_backend="pyav",
            height=video_stream.height,
            width=video_stream.width,
        )

        target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
        offset = float(start)

        target_timestamps = np.array(target_timestamps)
        # End time of each frame == start (pts) of the next; last frame ends at `duration`.
        end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
        # Pick the frame that is playing at each requested time.
        indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side='right')
        indices = np.minimum(indices, len(end_time_stamps) - 1)

        video = np.stack(
            [frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
            axis=0,
        )

        metadata.frames_indices = target_timestamps * fps

        return video, metadata
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# Map backend name -> decoding function; dispatched on in `load_video`.
VIDEO_DECODERS = {
    "decord": read_video_decord,
    "torchcodec": read_video_torchcodec,
    "pyav": read_video_pyav,
}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def load_video(
    video: VideoInput,
    backend: str = "decord",
    sample_timestamps_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Loads `video` to a numpy array.

    Args:
        video (`VideoInput`):
            The video to convert to the numpy array format. Can be a link to video or local path.
        backend (`str`, *optional*, defaults to `"decord"`):
            The backend to use when loading the video. Can be any of ["decord", "pyav", "torchcodec"]. Defaults to "decord".
        sample_timestamps_fn (`Callable`):
            A callable function that will return timestamps at which the video should be sampled.
    """

    # Early exit if provided an array or `PIL` frames
    if not isinstance(video, str):
        metadata = [None] * len(video)
        return video, metadata

    if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
        if not is_yt_dlp_available():
            raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
        # Lazy import from yt_dlp
        import importlib
        yt_dlp = importlib.import_module("yt_dlp")

        buffer = BytesIO()
        with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
            f.download([video])
        bytes_obj = buffer.getvalue()
        file_obj = BytesIO(bytes_obj)
    elif video.startswith("http://") or video.startswith("https://"):
        file_obj = BytesIO(requests.get(video).content)
    elif os.path.isfile(video):
        file_obj = video
    else:
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    # can also load with decord, but not cv2/torchvision
    # both will fail in case of url links
    video_is_url = video.startswith("http://") or video.startswith("https://")
    # NOTE(review): "opencv" is not a key of VIDEO_DECODERS, so this branch can
    # only fire before the KeyError below would — likely vestigial; confirm.
    if video_is_url and backend == "opencv":
        raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")

    if (
        (not is_decord_available() and backend == "decord")
        or (not is_torchcodec_available() and backend == "torchcodec")
        or (not is_av_available() and backend == "pyav")
    ):
        raise ImportError(
            f"You chose backend={backend} for loading the video but the required library is not found in your environment "
            f"Make sure to install {backend} before loading the video."
        )

    # Dispatch to the selected backend's reader.
    video_decoder = VIDEO_DECODERS[backend]
    video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
    return video, metadata
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def get_target_fps(
    video_fps: float,
    max_frames: int,
    total_frames: int,
    frame_sample_mode: str,
    candidate_target_fps: tuple[float],
) -> float:
    """
    Get the target fps that best spans the video and has the most frames sampled

    Returns None when no candidate fits (e.g. uniform mode and even the lowest
    candidate would exceed `max_frames`).
    """
    num_frames_sampled = 0
    selected_target_fps = None
    # Candidates are assumed to be in increasing order (see get_candidate_target_fps).
    for target_fps in candidate_target_fps:
        step_size = max(int(video_fps / target_fps), 1)
        num_frames_sampled_at_fps = int(total_frames / step_size)
        if num_frames_sampled == 0:
            # No candidate accepted yet.
            if "uniform" in frame_sample_mode:
                if num_frames_sampled_at_fps > max_frames:
                    # Even the lowest fps over-samples; fall back to uniform sampling.
                    break
            selected_target_fps = target_fps
            num_frames_sampled = num_frames_sampled_at_fps

        else:
            # the candidate sampling fps increases so frame count can't decrease
            assert num_frames_sampled <= num_frames_sampled_at_fps
            if num_frames_sampled_at_fps > max_frames:
                # choose the sampling fps that spans the video
                continue

            elif num_frames_sampled_at_fps > num_frames_sampled:
                # both are less than max_frames, choose the one with higher density of frames sampled
                selected_target_fps = target_fps
                num_frames_sampled = num_frames_sampled_at_fps
    return selected_target_fps
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def get_frame_times_and_chosen_fps(
    selected_target_fps,
    total_frames,
    max_frames,
    video_fps
):
    """Compute the frame indices to sample for a chosen target fps.

    When no target fps was selected, fall back to `max_frames` indices spread
    uniformly over the video; otherwise walk the video at a fixed stride and
    keep at most `max_frames` indices.
    """
    if selected_target_fps is None:
        # Uniform fallback: evenly spaced integer indices in [0, total_frames).
        chosen = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
        return selected_target_fps, chosen
    stride = max(int(video_fps / selected_target_fps), 1)
    chosen = np.arange(0, total_frames, stride)[:max_frames]
    return selected_target_fps, chosen
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
    # Extra keyword arguments accepted by MolmoAct2VideoProcessor on top of
    # the base `VideosKwargs`; all optional (total=False).
    patch_size: Optional[int]           # ViT patch size in pixels
    pooling_size: Optional[list[int]]   # [pool_h, pool_w] token-pooling window
    frame_sample_mode: Optional[str]    # e.g. "fps" or "uniform_last_frame"
    max_fps: Optional[int]              # upper bound on sampled frames per second
    sampling_fps: Optional[int]         # requested sampling rate (fps mode)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
| 561 |
+
resample = PILImageResampling.BILINEAR
|
| 562 |
+
size = {"height": 378, "width": 378}
|
| 563 |
+
image_mean = IMAGENET_STANDARD_MEAN
|
| 564 |
+
image_std = IMAGENET_STANDARD_STD
|
| 565 |
+
do_resize = True
|
| 566 |
+
do_rescale = True
|
| 567 |
+
do_normalize = True
|
| 568 |
+
do_convert_rgb = True
|
| 569 |
+
patch_size = 14
|
| 570 |
+
pooling_size = [3, 3]
|
| 571 |
+
do_sample_frames = True
|
| 572 |
+
frame_sample_mode = "uniform_last_frame"
|
| 573 |
+
max_fps = 2
|
| 574 |
+
sampling_fps = 2
|
| 575 |
+
valid_kwargs = MolmoAct2VideoProcessorKwargs
|
| 576 |
+
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
|
| 577 |
+
|
| 578 |
+
    def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
        """Initialize the processor and validate that any configured `size`
        carries explicit 'height' and 'width' entries."""
        super().__init__(**kwargs)
        if self.size is not None and (
            self.size.get("height", None) is None or self.size.get("width", None) is None
        ):
            raise ValueError("size must contain 'height' and 'width' keys.")
|
| 584 |
+
|
| 585 |
+
    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        # Reject partially-specified sizes early with a clear error.
        if size is not None and ("height" not in size or "width" not in size):
            raise ValueError("size must contain 'height' and 'width' keys.")

        return super()._further_process_kwargs(size=size, **kwargs)
|
| 598 |
+
|
| 599 |
+
def sample_times(
|
| 600 |
+
self,
|
| 601 |
+
metadata: VideoMetadata,
|
| 602 |
+
frame_sample_mode: str,
|
| 603 |
+
num_frames: int,
|
| 604 |
+
max_fps: Optional[int] = None,
|
| 605 |
+
sampling_fps: Optional[int] = None,
|
| 606 |
+
**kwargs,
|
| 607 |
+
) -> np.ndarray:
|
| 608 |
+
"""
|
| 609 |
+
Time-based sampling if an array video is passed
|
| 610 |
+
Args:
|
| 611 |
+
metadata (`VideoMetadata`):
|
| 612 |
+
Metadata of the video containing information about total duration, fps and total number of frames.
|
| 613 |
+
frame_sample_mode (`str`, *optional*):
|
| 614 |
+
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
| 615 |
+
num_frames (`int`, *optional*):
|
| 616 |
+
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
| 617 |
+
man_fps (`int`, *optional*):
|
| 618 |
+
Maximum frames per second to sample.
|
| 619 |
+
sampling_fps (`int`, *optional*):
|
| 620 |
+
Sampling frames per second. Defaults to `self.sampling_fps`.
|
| 621 |
+
Used when `frame_sample_mode` is `"fps"`.
|
| 622 |
+
"""
|
| 623 |
+
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
| 624 |
+
num_frames = num_frames or self.num_frames
|
| 625 |
+
sampling_fps = sampling_fps or self.sampling_fps
|
| 626 |
+
|
| 627 |
+
duration = metadata.duration or metadata.total_num_frames / metadata.fps
|
| 628 |
+
if frame_sample_mode == "fps":
|
| 629 |
+
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
| 630 |
+
# Try larger and larger FPSs until we hit one that can't span the video
|
| 631 |
+
target_fps = candidate_target_fps[0]
|
| 632 |
+
for candidate_fps in candidate_target_fps[1:]:
|
| 633 |
+
if num_frames / candidate_fps < duration:
|
| 634 |
+
break
|
| 635 |
+
target_fps = candidate_fps
|
| 636 |
+
times = np.arange(0, num_frames) / target_fps
|
| 637 |
+
times = times[times < duration]
|
| 638 |
+
return times
|
| 639 |
+
elif frame_sample_mode == "uniform_last_frame":
|
| 640 |
+
if max_fps is not None:
|
| 641 |
+
max_duration = (num_frames-1) / max_fps # -1 to include the last frame
|
| 642 |
+
if max_duration < duration:
|
| 643 |
+
times = np.linspace(
|
| 644 |
+
0, duration, num=num_frames, endpoint=True, dtype=np.float64
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
times = np.arange(0.0, stop=duration, step=1/max_fps)
|
| 648 |
+
times = np.concatenate([times, [duration]], axis=0)
|
| 649 |
+
assert len(times) <= num_frames
|
| 650 |
+
else:
|
| 651 |
+
times = np.linspace(
|
| 652 |
+
0, duration, num=num_frames, endpoint=True, dtype=np.float64
|
| 653 |
+
)
|
| 654 |
+
return times
|
| 655 |
+
else:
|
| 656 |
+
raise NotImplementedError(frame_sample_mode)
|
| 657 |
+
|
| 658 |
+
    def sample_frames(
        self,
        metadata: VideoMetadata,
        frame_sample_mode: Optional[str] = None,
        num_frames: Optional[int] = None,
        max_fps: Optional[int] = None,
        sampling_fps: Optional[int] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Frame-based sampling if an array video is passed

        Args:
            metadata (`VideoMetadata`):
                Metadata of the video containing information about total duration, fps and total number of frames.
            frame_sample_mode (`str`, *optional*):
                Mode to sample frames. Defaults to `self.frame_sample_mode`.
            num_frames (`int`, *optional*):
                Maximum number of frames to sample. Defaults to `self.num_frames`.
            max_fps (`int`, *optional*):
                Maximum frames per second to sample.
            sampling_fps (`int`, *optional*):
                Sampling frames per second. Defaults to `self.sampling_fps`.
                Used when `frame_sample_mode` is `"fps"`.
        """
        frame_sample_mode = frame_sample_mode or self.frame_sample_mode
        num_frames = num_frames or self.num_frames
        sampling_fps = sampling_fps or self.sampling_fps

        total_num_frames = metadata.total_num_frames
        if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
            duration = total_num_frames / metadata.fps
            if total_num_frames <= 2:
                # Too few frames to subsample: keep everything.
                return np.arange(total_num_frames).astype(int)
            if duration > (num_frames - 1) / max_fps:  # -1 to include the last frame
                # uniform fallback
                # Video too long to honor max_fps: spread indices uniformly;
                # endpoint=True guarantees the last frame is selected.
                indices = np.linspace(
                    0,
                    total_num_frames - 1,
                    num=min(num_frames, total_num_frames),
                    endpoint=True,
                ).astype(int)
                return indices
            else:
                # Step through the video at max_fps (converted to a frame stride),
                # then force-include the final frame if the stride missed it.
                float_indices = np.arange(
                    0.0, stop=total_num_frames - 1, step=float(metadata.fps / max_fps),
                )
                if np.round(float_indices[-1]) != total_num_frames - 1:
                    float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
                # NOTE(review): rounding can yield duplicate adjacent indices when
                # metadata.fps is close to max_fps — presumably acceptable
                # downstream; confirm.
                indices = np.round(float_indices).astype(int)
                assert indices[-1] < total_num_frames
                assert len(float_indices) <= num_frames
                return indices
        elif frame_sample_mode == "uniform_last_frame":
            # No fps cap: plain uniform sampling that always includes the last frame.
            indices = np.linspace(
                0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True,
            ).astype(int)
            return indices
        elif frame_sample_mode == "fps":
            # Delegate to the module-level helpers to pick a decodable target fps
            # and derive the matching frame indices.
            candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
            selected_target_fps = get_target_fps(
                metadata.fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidate_target_fps,
            )
            _, indices = get_frame_times_and_chosen_fps(
                selected_target_fps,
                total_num_frames,
                num_frames,
                metadata.fps,
            )
            return indices
        else:
            raise NotImplementedError(frame_sample_mode)
def fetch_videos(
|
| 735 |
+
self,
|
| 736 |
+
video_url_or_urls: Union[str, list[str], list[list[str]]],
|
| 737 |
+
sample_timestamps_fn=None
|
| 738 |
+
):
|
| 739 |
+
"""
|
| 740 |
+
Convert a single or a list of urls into the corresponding `np.array` objects.
|
| 741 |
+
|
| 742 |
+
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
| 743 |
+
returned.
|
| 744 |
+
"""
|
| 745 |
+
if (
|
| 746 |
+
(not is_decord_available())
|
| 747 |
+
and (not is_torchcodec_available())
|
| 748 |
+
and (not is_av_available())
|
| 749 |
+
):
|
| 750 |
+
raise ImportError(
|
| 751 |
+
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
if is_decord_available():
|
| 755 |
+
backend = "decord"
|
| 756 |
+
elif is_torchcodec_available():
|
| 757 |
+
warnings.warn(
|
| 758 |
+
"`decord` is not installed and cannot be used to decode the video by default. "
|
| 759 |
+
"Falling back to `torchcodec`."
|
| 760 |
+
)
|
| 761 |
+
backend = "torchcodec"
|
| 762 |
+
else:
|
| 763 |
+
warnings.warn(
|
| 764 |
+
"`decord` is not installed and cannot be used to decode the video by default. "
|
| 765 |
+
"Falling back to `PyAV`."
|
| 766 |
+
)
|
| 767 |
+
backend = "pyav"
|
| 768 |
+
|
| 769 |
+
if isinstance(video_url_or_urls, list):
|
| 770 |
+
return list(zip(*[self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn) for x in video_url_or_urls]))
|
| 771 |
+
else:
|
| 772 |
+
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
|
| 773 |
+
|
| 774 |
+
def _decode_and_sample_videos(
|
| 775 |
+
self,
|
| 776 |
+
videos: VideoInput,
|
| 777 |
+
video_metadata: Union[VideoMetadata, dict],
|
| 778 |
+
do_sample_frames: Optional[bool] = None,
|
| 779 |
+
sample_indices_fn: Optional[Callable] = None,
|
| 780 |
+
sample_timestamps_fn: Optional[Callable] = None,
|
| 781 |
+
):
|
| 782 |
+
"""
|
| 783 |
+
Decode input videos and sample frames if needed.
|
| 784 |
+
"""
|
| 785 |
+
videos = make_batched_videos(videos)
|
| 786 |
+
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
|
| 787 |
+
|
| 788 |
+
# Framed-based sampling if an array video is passed
|
| 789 |
+
# Otherwise, time-based sampling with decoding
|
| 790 |
+
if is_valid_video(videos[0]) and do_sample_frames:
|
| 791 |
+
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
|
| 792 |
+
sampled_videos = []
|
| 793 |
+
sampled_metadata = []
|
| 794 |
+
for video, metadata in zip(videos, video_metadata):
|
| 795 |
+
indices = sample_indices_fn(metadata=metadata)
|
| 796 |
+
metadata.frames_indices = indices
|
| 797 |
+
sampled_videos.append(video[indices])
|
| 798 |
+
sampled_metadata.append(metadata)
|
| 799 |
+
videos = sampled_videos
|
| 800 |
+
video_metadata = sampled_metadata
|
| 801 |
+
elif not is_valid_video(videos[0]):
|
| 802 |
+
if sample_indices_fn is None:
|
| 803 |
+
logger.warning(
|
| 804 |
+
"do_sample_frames is False, but video array is not provided: "
|
| 805 |
+
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
|
| 806 |
+
)
|
| 807 |
+
if isinstance(videos[0], list):
|
| 808 |
+
raise ValueError(
|
| 809 |
+
"A list of images is not supported for video input!"
|
| 810 |
+
)
|
| 811 |
+
else:
|
| 812 |
+
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
|
| 813 |
+
|
| 814 |
+
return videos, video_metadata
|
| 815 |
+
|
| 816 |
+
def _prepare_input_videos(
|
| 817 |
+
self,
|
| 818 |
+
videos: VideoInput,
|
| 819 |
+
**kwargs,
|
| 820 |
+
) -> list[np.ndarray]:
|
| 821 |
+
processed_videos = [to_numpy(video) for video in videos]
|
| 822 |
+
return processed_videos
|
| 823 |
+
|
| 824 |
+
def preprocess(
|
| 825 |
+
self,
|
| 826 |
+
videos: VideoInput,
|
| 827 |
+
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
|
| 828 |
+
) -> BatchFeature:
|
| 829 |
+
validate_kwargs(
|
| 830 |
+
captured_kwargs=kwargs.keys(),
|
| 831 |
+
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
| 835 |
+
# by the user, it gets its default value from the instance, or is set to None.
|
| 836 |
+
for kwarg_name in self.valid_kwargs.__annotations__:
|
| 837 |
+
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
| 838 |
+
|
| 839 |
+
do_sample_frames = kwargs.pop("do_sample_frames")
|
| 840 |
+
video_metadata = kwargs.pop("video_metadata")
|
| 841 |
+
|
| 842 |
+
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
|
| 843 |
+
sample_timestamps_fn = partial(self.sample_times, **kwargs)
|
| 844 |
+
videos, video_metadata = self._decode_and_sample_videos(
|
| 845 |
+
videos,
|
| 846 |
+
video_metadata=video_metadata,
|
| 847 |
+
do_sample_frames=do_sample_frames,
|
| 848 |
+
sample_indices_fn=sample_indices_fn,
|
| 849 |
+
sample_timestamps_fn=sample_timestamps_fn,
|
| 850 |
+
)
|
| 851 |
+
videos = self._prepare_input_videos(videos=videos)
|
| 852 |
+
|
| 853 |
+
kwargs = self._further_process_kwargs(**kwargs)
|
| 854 |
+
|
| 855 |
+
return_metadata = kwargs.pop("return_metadata")
|
| 856 |
+
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
|
| 857 |
+
if return_metadata:
|
| 858 |
+
preprocessed_videos["video_metadata"] = video_metadata
|
| 859 |
+
return preprocessed_videos
|
| 860 |
+
|
| 861 |
+
def _preprocess(
|
| 862 |
+
self,
|
| 863 |
+
videos: list[np.ndarray],
|
| 864 |
+
size: Optional[SizeDict] = None,
|
| 865 |
+
resample: Optional[PILImageResampling] = None,
|
| 866 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 867 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 868 |
+
do_convert_rgb: Optional[bool] = None,
|
| 869 |
+
patch_size: Optional[int] = None,
|
| 870 |
+
pooling_size: Optional[list[int]] = None,
|
| 871 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 872 |
+
**kwargs,
|
| 873 |
+
) -> BatchFeature:
|
| 874 |
+
"""
|
| 875 |
+
Preprocess a video for the model.
|
| 876 |
+
Args:
|
| 877 |
+
videos (`VideoInput`):
|
| 878 |
+
Video to preprocess.
|
| 879 |
+
size (`SizeDict`, *optional*, defaults to `self.size`):
|
| 880 |
+
Size of the image after resizing.
|
| 881 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 882 |
+
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 883 |
+
has an effect if `do_resize` is set to `True`.
|
| 884 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
| 885 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 886 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
| 887 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 888 |
+
`True`.
|
| 889 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 890 |
+
Whether to convert the image to RGB.
|
| 891 |
+
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
| 892 |
+
The spatial patch size of the vision encoder.
|
| 893 |
+
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
| 894 |
+
The pooling size of the vision adapter.
|
| 895 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 896 |
+
The type of tensors to return. Can be one of:
|
| 897 |
+
- Unset: Return a list of `np.ndarray`.
|
| 898 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 899 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 900 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 901 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 902 |
+
|
| 903 |
+
Returns:
|
| 904 |
+
A `BatchFeature` containing the following keys:
|
| 905 |
+
- `pixel_values_videos`: The preprocessed videos.
|
| 906 |
+
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
|
| 907 |
+
- `video_grids`: The video grids.
|
| 908 |
+
"""
|
| 909 |
+
if size.height is None or size.width is None:
|
| 910 |
+
raise ValueError("size must contain 'height' and 'width' keys.")
|
| 911 |
+
|
| 912 |
+
base_image_input_size = [size.height, size.width]
|
| 913 |
+
|
| 914 |
+
resample = resample or self.resample
|
| 915 |
+
image_mean = image_mean or self.image_mean
|
| 916 |
+
image_std = image_std or self.image_std
|
| 917 |
+
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
| 918 |
+
|
| 919 |
+
patch_size = patch_size or self.patch_size
|
| 920 |
+
pooling_size = pooling_size or self.pooling_size
|
| 921 |
+
|
| 922 |
+
image_pooling_h, image_pooling_w = pooling_size
|
| 923 |
+
|
| 924 |
+
batch_grids = []
|
| 925 |
+
batch_crops = []
|
| 926 |
+
batch_pooled_patches_idx = []
|
| 927 |
+
|
| 928 |
+
for video in videos:
|
| 929 |
+
all_crops = []
|
| 930 |
+
pooled_patches_idx = []
|
| 931 |
+
|
| 932 |
+
for frame in video:
|
| 933 |
+
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
| 934 |
+
frame,
|
| 935 |
+
base_image_input_size,
|
| 936 |
+
resample,
|
| 937 |
+
image_mean,
|
| 938 |
+
image_std,
|
| 939 |
+
patch_size,
|
| 940 |
+
image_pooling_w,
|
| 941 |
+
image_pooling_h,
|
| 942 |
+
)
|
| 943 |
+
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
|
| 944 |
+
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
|
| 945 |
+
pooled_patches_idx.append(pooled_idx_with_offset)
|
| 946 |
+
all_crops.append(crops)
|
| 947 |
+
|
| 948 |
+
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
|
| 949 |
+
all_crops = np.concatenate(all_crops, 0)
|
| 950 |
+
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
|
| 951 |
+
|
| 952 |
+
batch_grids.append(video_grid)
|
| 953 |
+
batch_crops.append(all_crops)
|
| 954 |
+
batch_pooled_patches_idx.append(pooled_patches_idx)
|
| 955 |
+
|
| 956 |
+
video_grids = np.stack(batch_grids, 0)
|
| 957 |
+
pixel_values_videos = np.concatenate(batch_crops, 0)
|
| 958 |
+
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
| 959 |
+
|
| 960 |
+
data =dict(
|
| 961 |
+
pixel_values_videos=pixel_values_videos,
|
| 962 |
+
video_token_pooling=video_token_pooling,
|
| 963 |
+
video_grids=video_grids,
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
return BatchFeature(data, tensor_type=return_tensors)
|
| 967 |
+
|
| 968 |
+
# Register so the processor is discoverable via the Auto* classes
# (presumably loaded with trust_remote_code — confirm against repo config).
MolmoAct2VideoProcessor.register_for_auto_class()