lfh committed
Commit eb868a1 · 1 Parent(s): 95a3948

remove files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +0 -35
  2. README.md +0 -57
  3. capvector-oft/.pre-commit-config.yaml +0 -27
  4. capvector-oft/ALOHA.md +0 -157
  5. capvector-oft/LIBERO.md +0 -130
  6. capvector-oft/LICENSE +0 -21
  7. capvector-oft/SETUP.md +0 -24
  8. capvector-oft/capvector/.gitignore +0 -8
  9. capvector-oft/capvector/compute_lora_diff.py +0 -35
  10. capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh +0 -8
  11. capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh +0 -26
  12. capvector-oft/capvector/interpolate.py +0 -247
  13. capvector-oft/capvector/interpolate.sh +0 -26
  14. capvector-oft/capvector/interpolate_robotwin.py +0 -247
  15. capvector-oft/capvector/tools/check_model_config.py +0 -23
  16. capvector-oft/capvector/tools/compute_lora_diff.py +0 -36
  17. capvector-oft/capvector/tools/compute_lora_diff.sh +0 -8
  18. capvector-oft/capvector/tools/vector_analyze.py +0 -153
  19. capvector-oft/capvector/tools/vector_regularize.py +0 -75
  20. capvector-oft/experiments/robot/aloha/aloha_utils.py +0 -85
  21. capvector-oft/experiments/robot/aloha/constants.py +0 -100
  22. capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py +0 -260
  23. capvector-oft/experiments/robot/aloha/real_env.py +0 -213
  24. capvector-oft/experiments/robot/aloha/requirements_aloha.txt +0 -26
  25. capvector-oft/experiments/robot/aloha/robot_utils.py +0 -187
  26. capvector-oft/experiments/robot/aloha/run_aloha_eval.py +0 -385
  27. capvector-oft/experiments/robot/libero/libero_requirements.txt +0 -6
  28. capvector-oft/experiments/robot/libero/libero_utils.py +0 -87
  29. capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py +0 -249
  30. capvector-oft/experiments/robot/libero/run_libero_eval.py +0 -540
  31. capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl +0 -3
  32. capvector-oft/experiments/robot/openvla_utils.py +0 -818
  33. capvector-oft/experiments/robot/robot_utils.py +0 -199
  34. capvector-oft/prismatic/__init__.py +0 -1
  35. capvector-oft/prismatic/conf/__init__.py +0 -3
  36. capvector-oft/prismatic/conf/datasets.py +0 -133
  37. capvector-oft/prismatic/conf/models.py +0 -584
  38. capvector-oft/prismatic/conf/vla.py +0 -235
  39. capvector-oft/prismatic/extern/__init__.py +0 -0
  40. capvector-oft/prismatic/extern/hf/__init__.py +0 -0
  41. capvector-oft/prismatic/extern/hf/configuration_prismatic.py +0 -140
  42. capvector-oft/prismatic/extern/hf/modeling_prismatic.py +0 -1085
  43. capvector-oft/prismatic/extern/hf/processing_prismatic.py +0 -252
  44. capvector-oft/prismatic/models/__init__.py +0 -2
  45. capvector-oft/prismatic/models/action_heads.py +0 -211
  46. capvector-oft/prismatic/models/backbones/__init__.py +0 -0
  47. capvector-oft/prismatic/models/backbones/llm/__init__.py +0 -4
  48. capvector-oft/prismatic/models/backbones/llm/base_llm.py +0 -223
  49. capvector-oft/prismatic/models/backbones/llm/llama2.py +0 -102
  50. capvector-oft/prismatic/models/backbones/llm/mistral.py +0 -72
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md DELETED
@@ -1,57 +0,0 @@
- # CapVector: Learning Transferable Capability Vectors in Parametric Space for Vision-Language-Action Models
-
- <div align="center">
-
- [![Paper](https://img.shields.io/badge/Paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](http://arxiv.org/abs/) [![Page](https://img.shields.io/badge/Project--Page-blue?style=for-the-badge&logo=homepage&logoColor=white)](https://capvector.github.io/) [![Hugging Face Collection](https://img.shields.io/badge/Models-fcd022?style=for-the-badge&logo=huggingface&logoColor=white)](https://huggingface.co/haofuly/capvector_models_collection)
-
- </div>
-
- CapVector is a training recipe for vision-language-action (VLA) models that extracts a transferable capability vector from the parameter difference between auxiliary-objective SFT methods and standard SFT methods. This vector is merged into a pretrained VLA to form a stronger initialization, and downstream adaptation uses standard SFT with a lightweight orthogonal regularization loss to preserve the injected capability.
-
-
- ## 🌟 Key Features
- - **Efficient downstream adaptation**: CapVector recovers much of the benefit of auxiliary-objective SFT methods while keeping the downstream overhead close to that of standard SFT.
- - **Versatility**: CapVector supports OpenVLA-based, OpenPi-based, and StarVLA-based backbones.
- - **Generalization**: CapVector is designed to transfer across tasks, environments, and robot embodiments.
-
-
- ## 🚀 Get Started
-
- This repository provides two implementation paths:
- - [`capvector-oft/`](./capvector-oft): OpenVLA-OFT-based implementation
- - [`capvector-pi05/`](./capvector-pi05): OpenPI-based implementation
-
- Choose the subdirectory that matches your base model and training stack, and follow the subproject README for environment setup, data preparation, training, and inference.
-
- [`capvector-pi05/`](./capvector-pi05) provides the capability vector extraction and merging scripts.
-
-
- ## 🌏 Contact
- For further discussion and collaboration, please feel free to contact us via email or WeChat:
-
- | Author | Email | WeChat |
- |:---:|:---:|:---:|
- | Wenxuan Song | songwenxuan0115@gmail.com | swx0757 |
-
-
- ## ❤️ Acknowledgments
-
- CapVector builds on and interfaces with several excellent open-source projects, including:
-
- - [OpenVLA-OFT](https://github.com/moojink/openvla-oft)
- - [OpenPI](https://github.com/Physical-Intelligence/openpi)
-
-
- ## 🖊 Citation
-
- If you find this work useful, please cite:
-
- ```bibtex
- @article{song2026capvector,
-   title = {CapVector: Learning Transferable Capability Vectors in Parametric Space for Vision-Language-Action Models},
-   author = {Song, Wenxuan and Zhao, Han and Li, Fuhao and Zhou, Ziyang and Wang, Xi and Lyu, Jing and Ding, Pengxiang and Wang, Yan and Wang, Donglin and Li, Haoang},
-   journal = {Preprint},
-   year = {2026}
- }
- ```
-
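The orthogonal regularization mentioned in the README is not shown anywhere in this diff. As a rough illustration of the idea only, here is a minimal PyTorch sketch of one plausible form of such a loss; the exact formulation, the `cap_vector` layout, and the `lambda_orth` weighting are assumptions, not the authors' method:

```python
import torch

def orthogonal_reg_loss(model, init_params, cap_vector, eps=1e-8):
    """Squared cosine similarity between the downstream SFT update
    (theta - theta_init) and the injected capability vector. Driving it
    toward zero keeps the update orthogonal to the capability direction,
    so standard SFT is discouraged from erasing the injected capability.
    NOTE: this is an assumed formulation, not the repo's implementation."""
    dot = torch.zeros((), device=next(model.parameters()).device)
    upd_sq, cap_sq = dot.clone(), dot.clone()
    for name, p in model.named_parameters():
        if name not in cap_vector:
            continue
        delta = (p - init_params[name].to(p.device)).flatten()  # SFT update so far
        v = cap_vector[name].flatten().to(p.device)             # injected capability
        dot = dot + (delta * v).sum()
        upd_sq = upd_sq + (delta * delta).sum()
        cap_sq = cap_sq + (v * v).sum()
    cos = dot / (upd_sq.sqrt() * cap_sq.sqrt() + eps)
    return cos.pow(2)

# Hypothetical usage inside the SFT training step:
# loss = sft_loss + lambda_orth * orthogonal_reg_loss(model, init_params, cap_vector)
```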
 
capvector-oft/.pre-commit-config.yaml DELETED
@@ -1,27 +0,0 @@
- # See https://pre-commit.com for more information
- # See https://pre-commit.com/hooks.html for more hooks
- exclude: ".git"
-
- repos:
-   - repo: https://github.com/astral-sh/ruff-pre-commit
-     rev: v0.2.2
-     hooks:
-       - id: ruff
-         args: [ --fix, --exit-non-zero-on-fix ]
-
-   - repo: https://github.com/psf/black
-     rev: 24.2.0
-     hooks:
-       - id: black
-
-   - repo: https://github.com/pre-commit/pre-commit-hooks
-     rev: v4.5.0
-     hooks:
-       - id: check-added-large-files
-       - id: check-ast
-       - id: check-case-conflict
-       - id: check-merge-conflict
-       - id: check-toml
-       - id: check-yaml
-       - id: end-of-file-fixer
-       - id: trailing-whitespace
 
capvector-oft/ALOHA.md DELETED
@@ -1,157 +0,0 @@
- # OpenVLA-OFT+ in Real-World ALOHA Robot Tasks
-
- ## Relevant Files
-
- Evaluation
- * `experiments/robot/aloha/`: ALOHA training and eval files
-   * `run_aloha_eval.py`: ALOHA eval script (CLIENT SIDE; see "SERVER SIDE" below)
-   * `aloha_utils.py`: ALOHA eval utils
-   * Other ALOHA robot environment files copied from the original [ALOHA GitHub repo](https://github.com/tonyzhaozh/aloha):
-     * `constants.py`
-     * `real_env.py`
-     * `robot_utils.py`
- * `experiments/robot/`: General eval utils files
-   * `openvla_utils.py`: OpenVLA-specific eval utils
-   * `robot_utils.py`: Other eval utils
- * `vla-scripts/deploy.py`: VLA server deploy script (SERVER SIDE)
-
- Note: Unlike the LIBERO evaluation setup, we use a server-client interface here. This is particularly useful if the user's machine which commands the robot does not have access to a local GPU with sufficient specs to run the fine-tuned VLA policies.
-
- Training
- * `experiments/robot/aloha/`: ALOHA training and eval files
-   * `preprocess_split_aloha_data.py`: ALOHA data preprocessing script
- * `vla-scripts/finetune.py`: VLA fine-tuning script
-
- ## Setup
-
- Set up a conda environment for training policies and deploying them on the VLA server (see instructions in [SETUP.md](SETUP.md)).
-
- ## Fine-Tuning on ALOHA Robot Data
-
- We assume that you have already collected a set of expert demonstrations on the ALOHA robot.
-
- First, use our `preprocess_split_aloha_data.py` script to preprocess the raw ALOHA dataset: downsize images from 480x640 to 256x256 and split the data into training and validation sets. Below are examples for the `put X into pot` task in our paper (which has 3 possible target objects, 1 per episode):
-
- ```bash
- python experiments/robot/aloha/preprocess_split_aloha_data.py \
-   --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \
-   --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
-   --percent_val 0.05
- python experiments/robot/aloha/preprocess_split_aloha_data.py \
-   --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \
-   --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
-   --percent_val 0.05
- python experiments/robot/aloha/preprocess_split_aloha_data.py \
-   --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \
-   --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
-   --percent_val 0.05
- ```
-
- Then, convert the preprocessed ALOHA datasets into a single RLDS dataset that is compatible with OpenVLA fine-tuning. This process is the same as in the original OpenVLA repo. See instructions for converting to RLDS [here](https://github.com/moojink/rlds_dataset_builder) (a sample ALOHA preprocessed-to-RLDS conversion script is available [here](https://github.com/moojink/rlds_dataset_builder/blob/main/aloha1_put_X_into_pot_300_demos/aloha1_put_X_into_pot_300_demos_dataset_builder.py); this script converts the three preprocessed datasets above into one unified RLDS dataset, with train/val splits).
-
- After converting to RLDS, register the dataset (which, for the example task above, would be called `aloha1_put_X_into_pot_300_demos`) with our dataloader by adding an entry for it in `configs.py` ([here](prismatic/vla/datasets/rlds/oxe/configs.py#L680)), `transforms.py` ([here](prismatic/vla/datasets/rlds/oxe/transforms.py#L928)), and `mixtures.py` ([here](prismatic/vla/datasets/rlds/oxe/mixtures.py#L216)). For reference, each of these files contains sample entries for the ALOHA datasets that we used in our paper.
-
- Before fine-tuning, set the desired ALOHA action chunk size in [`prismatic/vla/constants.py`](prismatic/vla/constants.py) (see `NUM_ACTIONS_CHUNK` in `ALOHA_CONSTANTS`). We set it to 25 by default because we used a control frequency of 25 Hz in our ALOHA setup to reduce storage costs and training time (while still maintaining smoothness in the robot's motions). If you use 50 Hz, we recommend setting `NUM_ACTIONS_CHUNK` to `50`. In general, 1 second-long action chunks are a good default. Do NOT modify `ACTION_PROPRIO_NORMALIZATION_TYPE`: since the ALOHA robot action space is absolute joint angles, we do not want to use a normalization scheme that clips outlier values (like the Q1-Q99 normalization we used with the relative end-effector pose actions for LIBERO), since that would prevent the model from outputting certain robot joint angles that are crucial for solving the task.
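To make the clipping concern concrete, here is a small illustrative sketch (not from this repo) comparing Q1-Q99 normalization, which clips outliers, against plain min-max bounds normalization, which keeps every absolute joint angle expressible:

```python
import numpy as np

# Toy action dataset: mostly mid-range joint angles, plus rare but
# task-critical extremes (e.g., a far-reaching joint configuration).
actions = np.concatenate([np.random.uniform(-0.5, 0.5, 10_000), [2.8, -2.8]])

# Q1-Q99 normalization (fine for LIBERO's relative EEF actions): the
# 1st/99th percentiles become the [-1, 1] bounds, so the rare extremes
# are clipped and can never be reproduced at inference time.
q1, q99 = np.percentile(actions, [1, 99])
q_norm = np.clip(2 * (actions - q1) / (q99 - q1) - 1, -1, 1)

# Min-max bounds normalization: every reachable joint angle maps to a
# distinct value in [-1, 1], so the extremes remain expressible.
lo, hi = actions.min(), actions.max()
mm_norm = 2 * (actions - lo) / (hi - lo) - 1

print("values pinned to the Q1-Q99 bounds:", int((np.abs(q_norm) == 1).sum()))
print("values pinned to the min-max bounds:", int((np.abs(mm_norm) >= 1).sum()))  # only the two endpoints
```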
-
- Now begin fine-tuning! Below is a sample command to fine-tune OpenVLA using our OFT+ recipe on the `put X into pot` task above (the "+" in "OFT+" means FiLM is included for enhanced language grounding). Replace `X` in the first line with the number of GPUs available to you.
-
- ```bash
- torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \
-   --vla_path openvla/openvla-7b \
-   --data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \
-   --dataset_name aloha1_put_X_into_pot_300_demos \
-   --run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \
-   --use_l1_regression True \
-   --use_diffusion False \
-   --use_film True \
-   --num_images_in_input 3 \
-   --use_proprio True \
-   --batch_size 4 \
-   --learning_rate 5e-4 \
-   --num_steps_before_decay 50000 \
-   --max_steps 100005 \
-   --use_val_set True \
-   --val_freq 10000 \
-   --save_freq 10000 \
-   --save_latest_checkpoint_only False \
-   --image_aug True \
-   --lora_rank 32 \
-   --wandb_entity "YOUR_WANDB_ENTITY" \
-   --wandb_project "YOUR_WANDB_PROJECT" \
-   --run_id_note parallel_dec--25_acts_chunk--continuous_acts--L1_regression--3rd_person_img--left_right_wrist_imgs--proprio_state--film
- ```
-
- The above training command should reproduce our OpenVLA-OFT+ results on the `put X into pot` task if `X = 8` and the 100K-step checkpoint is evaluated. It fine-tunes OpenVLA using 3 input images (1 third-person image + 2 wrist camera images). Note that we decay the learning rate after a certain point (50K steps in the command above) since doing so speeds up training convergence (in our experience, the training L1 loss spikes down).
-
- Best practices for fine-tuning:
- * In general, we recommend fine-tuning until the training L1 loss goes below 0.01 and starts to plateau.
- * One way to achieve this is to fine-tune using our default learning rate of `5e-4` until the loss starts to decrease very slowly, then decay the learning rate by 10x to `5e-5` (which should make the loss spike down) and train until the training L1 loss finally plateaus (see the scheduler sketch after this list).
- * Depending on your dataset size, you may need to adjust some hyperparameters. For example, if you use a large dataset with over 300 demos, you may need to decay the learning rate later and train for longer for best performance. Decaying too early can lead to a suboptimal policy.
- * If your task does not require good language grounding (e.g., if there is only one language instruction), FiLM is not necessary; consider setting `--use_film False` to train fewer model parameters.
- * Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). See our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you).
-
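The one-time 10x decay described above corresponds to a simple step schedule. A minimal PyTorch sketch of the idea; the optimizer setup here is illustrative, not the repo's `finetune.py`:

```python
import torch

# Illustrative parameters; the real script builds its optimizer from the
# training config (e.g., --learning_rate 5e-4, --num_steps_before_decay 50000).
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=5e-4)

# Decay the learning rate by 10x once, at the chosen step threshold.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[50_000], gamma=0.1
)

for step in range(100_005):
    # ... forward pass, L1 loss on the predicted action chunk, backward ...
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # lr stays 5e-4 until step 50K, then drops to 5e-5
```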
- If you run into any issues, please open a new GitHub issue.
-
- ## Launching ALOHA Robot Evaluations
-
- In the primary conda environment (`openvla-oft`), which you will use to launch the VLA server, install a few packages for the server-client interface:
-
- ```bash
- conda activate openvla-oft
- pip install uvicorn fastapi json-numpy
- ```
-
- On the machine that you will use to command the robot, set up a second conda environment that will be used to run the robot environment, query the VLA server, and execute actions in the environment:
-
- ```bash
- # Create and activate client conda environment
- conda create -n openvla-oft-aloha python=3.10 -y
- conda activate openvla-oft-aloha
-
- # Install PyTorch
- # Use a command specific to your machine: https://pytorch.org/get-started/locally/
- pip3 install torch torchvision torchaudio
-
- # Clone openvla-oft repo and pip install to download dependencies
- git clone https://github.com/moojink/openvla-oft.git
- cd openvla-oft
- pip install -e .
-
- # Install packages needed for the ALOHA robot environment
- pip install -r experiments/robot/aloha/requirements_aloha.txt
- ```
-
- Launch the VLA server on the machine that has the GPU you will use to run model inference (using the `openvla-oft` conda environment). Below is a sample command for this (change as needed):
-
- ```bash
- python vla-scripts/deploy.py \
-   --pretrained_checkpoint /PATH/TO/FINETUNED/MODEL/CHECKPOINT/DIR/ \
-   --use_l1_regression True \
-   --use_film True \
-   --num_images_in_input 3 \
-   --use_proprio True \
-   --center_crop True \
-   --unnorm_key aloha1_put_X_into_pot_300_demos
- ```
-
- Then, run the ALOHA evaluation script. Specify the VLA server URL or IP address in the `vla_server_url` argument. Below is a sample command:
-
- ```bash
- python experiments/robot/aloha/run_aloha_eval.py \
-   --center_crop True \
-   --num_open_loop_steps 25 \
-   --use_vla_server True \
-   --vla_server_url <URL OF VLA SERVER> \
-   --num_rollouts_planned <NUM TEST ROLLOUTS> \
-   --max_steps <MAX NUM STEPS PER ROLLOUT>
- ```
-
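For orientation, the client side of this interface boils down to a JSON POST carrying numpy arrays. A hypothetical minimal client sketch follows; the endpoint path, port, and payload keys are assumptions, so see `vla-scripts/deploy.py` and `run_aloha_eval.py` for the real interface:

```python
import json_numpy
import numpy as np
import requests

json_numpy.patch()  # lets the json module round-trip numpy arrays

# Hypothetical observation payload; the actual keys are defined by deploy.py.
observation = {
    "full_image": np.zeros((256, 256, 3), dtype=np.uint8),  # third-person camera
    "state": np.zeros(14),                                   # proprio (both arms)
    "instruction": "put the green pepper into the pot",
}

# Query the VLA server for the next action chunk, then execute it open-loop.
actions = requests.post(
    "http://<URL OF VLA SERVER>:<PORT>/act", json=observation  # endpoint assumed
).json()
```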
- If you run into any issues, please open a new GitHub issue.
-
- ## Troubleshooting Tips
-
- * Tip #1: If you run into a ROS error such as `ImportError: /lib/x86_64-linux-gnu/libp11-kit.so.0: undefined symbol: ffi_type_pointer, version LIBFFI_BASE_7.0`, try running the following command in your client conda environment (`openvla-oft-aloha`):
-
- ```
- conda install -c conda-forge libffi
- ```
 
capvector-oft/LIBERO.md DELETED
@@ -1,130 +0,0 @@
- # OpenVLA-OFT in the LIBERO Simulation Benchmark
-
- ## Relevant Files
-
- Evaluation
- * `experiments/robot/libero/`: LIBERO eval files
-   * `run_libero_eval.py`: LIBERO eval script
-   * `libero_utils.py`: LIBERO eval utils
- * `experiments/robot/`: General eval utils files
-   * `openvla_utils.py`: OpenVLA-specific eval utils
-   * `robot_utils.py`: Other eval utils
-
- Training
- * `vla-scripts/finetune.py`: VLA fine-tuning script
-
-
- ## Setup
-
- Set up a conda environment (see instructions in [SETUP.md](SETUP.md)).
-
- Clone and install the [LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) and required packages:
-
- ```bash
- git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git
- pip install -e LIBERO
- pip install -r experiments/robot/libero/libero_requirements.txt  # From openvla-oft base dir
- ```
-
- (Optional, if you plan to launch training) To download the [LIBERO datasets](https://huggingface.co/datasets/openvla/modified_libero_rlds) that we used in our fine-tuning experiments, run the command below. This will download the LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, and LIBERO-10 datasets in RLDS data format (~10 GB total). You can use these to fine-tune OpenVLA or train other methods. This step is optional since we provide pretrained OpenVLA-OFT checkpoints below. Note that these are the same datasets used in the original OpenVLA project. If needed, see details on how to download the original non-RLDS datasets [here](https://github.com/openvla/openvla?tab=readme-ov-file#libero-setup).
-
- ```bash
- git clone git@hf.co:datasets/openvla/modified_libero_rlds
- ```
-
- ## Launching LIBERO Evaluations
-
- We fine-tuned OpenVLA via LoRA (r=32) with our OFT recipe on four LIBERO task suites: LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, and LIBERO-10 (also called LIBERO-Long). In the initial version of our paper, we trained one checkpoint for each LIBERO task suite independently. In an updated version of the paper, we conducted an additional experiment in which we trained a single policy on all four task suites combined (results are available in the Additional Experiments section of the Appendix). Overall, the results for the task-specific policies and the combined policy are comparable: 97.1% vs. 96.8% average success rate across the four suites, respectively.
-
- Below are the four independently trained OpenVLA-OFT checkpoints for LIBERO:
- * [moojink/openvla-7b-oft-finetuned-libero-spatial](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial)
- * [moojink/openvla-7b-oft-finetuned-libero-object](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-object)
- * [moojink/openvla-7b-oft-finetuned-libero-goal](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-goal)
- * [moojink/openvla-7b-oft-finetuned-libero-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-10)
-
- Below is the OpenVLA-OFT checkpoint trained on all four task suites combined:
- * [moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10)
-
- To start evaluations with one of the independently trained checkpoints, run one of the commands below. Each will automatically download the appropriate checkpoint listed above. You can set the `TRANSFORMERS_CACHE` and `HF_HOME` environment variables to change where the checkpoint files get cached.
-
- ```bash
- # Launch LIBERO-Spatial evals
- python experiments/robot/libero/run_libero_eval.py \
-   --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-spatial \
-   --task_suite_name libero_spatial
-
- # Launch LIBERO-Object evals
- python experiments/robot/libero/run_libero_eval.py \
-   --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-object \
-   --task_suite_name libero_object
-
- # Launch LIBERO-Goal evals
- python experiments/robot/libero/run_libero_eval.py \
-   --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-goal \
-   --task_suite_name libero_goal
-
- # Launch LIBERO-10 (LIBERO-Long) evals
- python experiments/robot/libero/run_libero_eval.py \
-   --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-10 \
-   --task_suite_name libero_10
- ```
-
- To evaluate the policy trained on all four task suites together, simply swap out the `--pretrained_checkpoint` in the commands above with `moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10`.
-
- Notes:
- * The evaluation script runs 500 trials by default (10 tasks x 50 episodes each). You can modify the number of trials per task by setting `--num_trials_per_task`. You can also change the random seed via `--seed`. There are other arguments in the script; we set them to the default values that work with the OpenVLA-OFT checkpoints above.
- * **NOTE: Setting `--center_crop True` is important** because we fine-tuned OpenVLA with random crop augmentations (we took a random crop with 90% area in every training sample, so at test time we simply take the center 90% crop; see the sketch after these notes).
- * The evaluation script logs results locally. You can also log results in Weights & Biases by setting `--use_wandb True` and specifying `--wandb_project <PROJECT>` and `--wandb_entity <ENTITY>`.
- * The results reported in our paper were obtained using **Python 3.10.14, PyTorch 2.2.0, and our [custom transformers v4.40.1 fork](https://github.com/moojink/transformers-openvla-oft.git)** on an **NVIDIA A100 GPU**, averaged over three random seeds. Please stick to these package versions if possible. Note that results may vary slightly if you use a different GPU than the A100. If the discrepancy is large, please post a GitHub issue, and we will look into it.
-
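To make the center-crop note concrete, here is a minimal sketch of a center crop that keeps 90% of the image area (each side is scaled by sqrt(0.9)); it mirrors the described augmentation, not necessarily the repo's exact implementation in the eval utils:

```python
import math

from PIL import Image

def center_crop_90_area(img: Image.Image) -> Image.Image:
    """Center crop keeping 90% of the pixel area, then resize back."""
    w, h = img.size
    scale = math.sqrt(0.9)            # 90% area => sqrt(0.9) per side
    cw, ch = int(w * scale), int(h * scale)
    left, top = (w - cw) // 2, (h - ch) // 2
    cropped = img.crop((left, top, left + cw, top + ch))
    return cropped.resize((w, h), Image.BILINEAR)

# img = Image.open("obs.png"); img = center_crop_90_area(img)
```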
- ## Fine-Tuning on LIBERO Datasets
-
- First, download the LIBERO datasets as described in the Setup section above: `libero_spatial_no_noops`, `libero_object_no_noops`, `libero_goal_no_noops`, `libero_10_no_noops` (`"_no_noops"` stands for "no no-op actions", i.e., training samples with near-zero actions are filtered out).
-
- Then, launch the fine-tuning script with the OFT configuration below, replacing `X` in the first line with the number of GPUs. The command below launches fine-tuning on LIBERO-Spatial with the hyperparameters that we used in our paper. Here, batch size 8 per GPU will require ~62 GB VRAM, and batch size 1 per GPU will require ~25 GB VRAM.
-
- ```bash
- torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \
-   --vla_path openvla/openvla-7b \
-   --data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \
-   --dataset_name libero_spatial_no_noops \
-   --run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \
-   --use_l1_regression True \
-   --use_diffusion False \
-   --use_film False \
-   --num_images_in_input 2 \
-   --use_proprio True \
-   --batch_size 8 \
-   --learning_rate 5e-4 \
-   --num_steps_before_decay 100000 \
-   --max_steps 150005 \
-   --save_freq 10000 \
-   --save_latest_checkpoint_only False \
-   --image_aug True \
-   --lora_rank 32 \
-   --wandb_entity "YOUR_WANDB_ENTITY" \
-   --wandb_project "YOUR_WANDB_PROJECT" \
-   --run_id_note parallel_dec--8_acts_chunk--continuous_acts--L1_regression--3rd_person_img--wrist_img--proprio_state
- ```
-
- The above training command should reproduce our OpenVLA-OFT results if `X = 8` and the 150K-step checkpoint is evaluated.
-
- You can replace `libero_spatial_no_noops` with `libero_object_no_noops`, `libero_goal_no_noops`, or `libero_10_no_noops`. You can also modify other args; e.g., if you want to train with just one input image from the third-person camera and disable proprio state input, you can set `--num_images_in_input 1` and `--use_proprio False`.
-
- In general, we recommend fine-tuning until the training L1 loss goes below 0.01 and starts to plateau (with the above configuration, it should reach ~0.006 L1 loss on LIBERO-Spatial after 150K gradient steps with 10x LR decay after 100K steps). However, for LIBERO-Goal only, we found that the 50K checkpoint (which was at ~0.02 L1 loss) performed best, for unknown reasons. For all other task suites, we found that the 150K checkpoint performed best.
-
- Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). See our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you).
-
- If you run into any issues, please open a new GitHub issue. If you do not receive a response within 2 business days, please email Moo Jin Kim (moojink@cs.stanford.edu) to bring the issue to his attention.
 
capvector-oft/LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) 2025 Moo Jin Kim, Chelsea Finn, Percy Liang.
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
 
capvector-oft/SETUP.md DELETED
@@ -1,24 +0,0 @@
- # Setup Instructions
-
- ## Set Up Conda Environment
-
- ```bash
- # Create and activate conda environment
- conda create -n capvector-openvla-oft python=3.10 -y
- conda activate capvector-openvla-oft
-
- # Install PyTorch
- # Use a command specific to your machine: https://pytorch.org/get-started/locally/
- pip3 install torch torchvision torchaudio
-
- # Clone the CapVector repo and pip install from the OFT subproject to download dependencies
- git clone https://github.com/Songwxuan/CapVector
- cd CapVector/capvector-oft
- pip install -e .
-
- # Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
- # =>> If you run into difficulty, try `pip cache remove flash_attn` first
- pip install packaging ninja
- ninja --version; echo $?  # Verify Ninja --> should return exit code "0"
- pip install "flash-attn==2.5.5" --no-build-isolation
- ```
 
capvector-oft/capvector/.gitignore DELETED
@@ -1,8 +0,0 @@
- bin/
- draw_pic/
- feature_vector_ckpt/
- figure/
- id_extrapolation/
- id_interpolation/
- initialized_pt_vla/
- lora_diff/
 
capvector-oft/capvector/compute_lora_diff.py DELETED
@@ -1,35 +0,0 @@
- from safetensors.torch import load_file, save_file
- import torch
- import argparse
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--base", required=True)
-     parser.add_argument("--target", required=True)
-     parser.add_argument("--out", default="lora_diff.safetensors")
-     args = parser.parse_args()
-
-     base = load_file(args.base)
-     target = load_file(args.target)
-
-     diff = {}
-
-     print("=== Key Comparison ===")
-     only_in_base = set(base) - set(target)
-     only_in_target = set(target) - set(base)
-
-     print("Only in base:", list(only_in_base)[:10])
-     print("Only in target:", list(only_in_target)[:10])
-
-     for k in target:
-         if k in base:
-             diff[k] = target[k] - base[k]
-         else:
-             # Parameters present only in the target are retained as-is
-             diff[k] = target[k].clone()
-
-     save_file(diff, args.out)
-     print(f"\nSaved diff to: {args.out}")
-
- if __name__ == "__main__":
-     main()
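Applying a saved diff later is the inverse operation. A hypothetical helper sketch; the function name and usage are illustrative and not part of the repo:

```python
from safetensors.torch import load_file, save_file

def apply_lora_diff(base_path: str, diff_path: str, out_path: str, alpha: float = 1.0):
    """Add a stored adapter diff (scaled by alpha) back onto a base adapter."""
    base = load_file(base_path)
    diff = load_file(diff_path)
    merged = {k: v.clone() for k, v in base.items()}
    for k, d in diff.items():
        # Keys only present in the diff are added as-is (scaled by alpha).
        merged[k] = merged.get(k, 0) + alpha * d
    save_file(merged, out_path)

# apply_lora_diff("adapter_model.safetensors", "lora_diff.safetensors", "merged_adapter.safetensors")
```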
 
capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh DELETED
@@ -1,8 +0,0 @@
- BASE_ADAPTER="checkpoints/reference_models/openvla_oft_libero_spatial/lora_adapter/adapter_model.safetensors"
- TARGET_ADAPTER="checkpoints/task_models/SF_spatial/lora_adapter/adapter_model.safetensors"
- OUTPUT_DIFF="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors"
-
- python compute_lora_diff.py \
-   --base "$BASE_ADAPTER" \
-   --target "$TARGET_ADAPTER" \
-   --out "$OUTPUT_DIFF"
 
capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh DELETED
@@ -1,26 +0,0 @@
- TASK=bigbin_pot_microwave_qrcode_bowlsthree  # Customize for your task
- VERSION=53
- PT_CKPT="checkpoints/openvla_base"
- TASK_MODEL_CHECKPOINT="checkpoints/task_models/v106.1"
- REFERENCE_MODEL_CHECKPOINT="checkpoints/reference_models/v106.0"
- VECTOR_SAVE_PATH="checkpoints/feature_vectors/feature_vector_with_SF_${TASK}_v${VERSION}.pth"
- INITIALIZED_PT_VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_${TASK}_v${VERSION}"
- TASK_SUITE_NAME="ALOHA_${TASK}"
-
- python interpolate_robotwin.py \
-   --pretrained_checkpoint "$TASK_MODEL_CHECKPOINT" \
-   --original_pretrained_checkpoint "$REFERENCE_MODEL_CHECKPOINT" \
-   --vector_save_path "$VECTOR_SAVE_PATH" \
-   --initialized_pt_vla_path $INITIALIZED_PT_VLA_PATH \
-   --pt_ckpt $PT_CKPT \
-   --feature_vector_weight 1.1 \
-   --task_suite_name $TASK_SUITE_NAME
-
- # The rsync below copies everything except the VLA backbone weights (e.g., processor
- # and tokenizer files) from the base checkpoint, so that the initialized model is complete.
-
- rsync -av \
-   --ignore-existing \
-   --exclude='*.safetensors' \
-   --exclude='*.back.*' \
-   $PT_CKPT/ \
-   $INITIALIZED_PT_VLA_PATH/
 
capvector-oft/capvector/interpolate.py DELETED
@@ -1,247 +0,0 @@
- """
- Extracts a feature (capability) vector from a fine-tuned OpenVLA-OFT model and
- interpolates it into the original OpenVLA model.
- """
-
-
- import os
- import json
- import logging
-
- import sys
- from collections import deque
- from dataclasses import dataclass
- from enum import Enum
- from pathlib import Path
- from typing import Optional, Union
- from PIL import Image
-
- import draccus
- import numpy as np
- from tqdm import tqdm
- import torch
- import copy
-
- import wandb
-
- REPO_ROOT = Path(__file__).resolve().parents[1]
- if str(REPO_ROOT) not in sys.path:
-     sys.path.append(str(REPO_ROOT))
- from experiments.robot.openvla_utils import (
-     get_action_head,
-     get_noisy_action_projector,
-     get_processor,
-     get_proprio_projector,
-     resize_image_for_policy,
- )
- from experiments.robot.robot_utils import (
-     DATE_TIME,
-     get_action,
-     get_image_resize_size,
-     get_model,
-     invert_gripper_action,
-     normalize_gripper_action,
-     set_seed_everywhere,
- )
- # TaskSuite is required by validate_config below (it was missing from the original imports).
- from experiments.robot.libero.run_libero_eval import TaskSuite, check_unnorm_key
- from prismatic.vla.constants import NUM_ACTIONS_CHUNK
-
-
- # Set up logging
- logging.basicConfig(
-     level=logging.INFO,
-     format="%(asctime)s [%(levelname)s] %(message)s",
-     handlers=[logging.StreamHandler()],
- )
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class GenerateConfig:
-     # fmt: off
-
-     #################################################################################################################
-     # Model-specific parameters
-     #################################################################################################################
-     model_family: str = "openvla"                    # Model family
-     # The task-specific model after SF fine-tuning
-     pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model"                # Task-specific checkpoint path
-     # The task-specific model after OFT fine-tuning
-     original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model"  # Reference checkpoint path
-     # The feature vector is the difference between the two models, which represents the spatial features
-     vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
-     # The PT VLA model initialized with the feature vector; naming rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero}
-     initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
-     # The original pretrained OpenVLA model
-     pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
-     # The weight of the feature vector when initializing the PT VLA model
-     feature_vector_weight: float = 1                 # Weight of feature vector for interpolation
-
-     use_l1_regression: bool = True                   # If True, uses continuous action head with L1 regression objective
-     use_diffusion: bool = False                      # If True, uses continuous action head with diffusion modeling objective (DDIM)
-     num_diffusion_steps_train: int = 50              # (When `diffusion==True`) Number of diffusion steps used for training
-     num_diffusion_steps_inference: int = 50          # (When `diffusion==True`) Number of diffusion steps used for inference
-     use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
-     num_images_in_input: int = 2                     # Number of images in the VLA input (default: 1)
-     use_proprio: bool = True                         # Whether to include proprio state in input
-
-     center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
-     num_open_loop_steps: int = 8                     # Number of actions to execute open-loop before requerying policy
-
-     lora_rank: int = 32                              # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
-
-     unnorm_key: Union[str, Path] = ""                # Action un-normalization key
-
-     load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
-     load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization
-
-     #################################################################################################################
-     # LIBERO environment-specific parameters
-     #################################################################################################################
-     task_suite_name: str = "de"                      # Task suite
-     num_steps_wait: int = 10                         # Number of steps to wait for objects to stabilize in sim
-     num_trials_per_task: int = 50                    # Number of rollouts per task
-     initial_states_path: str = "DEFAULT"             # "DEFAULT", or path to initial states JSON file
-     env_img_res: int = 256                           # Resolution for environment images (not policy input resolution)
-
-     #################################################################################################################
-     # Utils
-     #################################################################################################################
-     run_id_note: Optional[str] = None                # Extra note to add to end of run ID for logging
-     local_log_dir: str = "./experiments/logs"        # Local directory for eval logs
-
-     use_wandb: bool = False                          # Whether to also log results in Weights & Biases
-     wandb_entity: str = "your-wandb-entity"          # Name of WandB entity
-     wandb_project: str = "your-wandb-project"        # Name of WandB project
-
-     seed: int = 7                                    # Random Seed (for reproducibility)
-
- def validate_config(cfg: GenerateConfig) -> None:
-     """Validate configuration parameters."""
-     assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"
-
-     if "image_aug" in str(cfg.pretrained_checkpoint):
-         assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
-
-     assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"
-
-     # Validate task suite
-     assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
-
- def initialize_model(cfg: GenerateConfig, only_pt: bool = False):  # Load action_head and noisy_action_projector separately
-     """Initialize model and associated components."""
-     # Load model
-     model = get_model(cfg)
-
-     # Load proprio projector if needed
-     proprio_projector = None
-     if cfg.use_proprio:
-         proprio_projector = get_proprio_projector(
-             cfg,
-             model.llm_dim,
-             proprio_dim=8,  # 8-dimensional proprio for LIBERO
-         )
-
-     # Load action head if needed
-     action_head = None
-     if cfg.use_l1_regression or cfg.use_diffusion:
-         action_head = get_action_head(cfg, model.llm_dim)
-
-     # Load noisy action projector if using diffusion
-     noisy_action_projector = None
-     if cfg.use_diffusion:
-         noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)
-
-     # Get OpenVLA processor if needed
-     processor = None
-     if not only_pt:
-         if cfg.model_family == "openvla":
-             processor = get_processor(cfg)
-             check_unnorm_key(cfg, model)
-
-     return model, action_head, proprio_projector, noisy_action_projector, processor
-
- # @draccus.wrap()
- def generate_feature_vector(cfg: GenerateConfig):
-     """Generate a feature vector (parameter differences) between two task-specific models."""
-     # Set random seed
-     set_seed_everywhere(cfg.seed)
-
-     # Initialize model and components
-     model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
-
-     original_config = GenerateConfig(
-         pretrained_checkpoint=cfg.original_pretrained_checkpoint,
-         task_suite_name=cfg.task_suite_name,
-     )
-
-     original_model, original_action_head, original_proprio_projector, original_noisy_action_projector, original_processor = initialize_model(original_config)
-     # action_head and noisy_action_projector modules are not interpolated
-     assert len(model.state_dict()) == len(original_model.state_dict())
-     feature_vector_dict = {}
-     total = len(original_model.state_dict())
-     for name, original_model_param in tqdm(original_model.named_parameters(), total=total):
-         model_param = model.state_dict()[name]
-         feature_vector_dict[name] = (model_param - original_model_param).detach().cpu()
-
-     return feature_vector_dict
-
- # @draccus.wrap()
- def interpolate_feature_vector(cfg: GenerateConfig):
-     """Interpolate the feature vector into the pretrained VLA."""
-     feature_vector_dict = torch.load(cfg.vector_save_path)
-
-     pt_vla_config = GenerateConfig(
-         pretrained_checkpoint=cfg.pt_ckpt,
-         original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
-         vector_save_path=cfg.vector_save_path,
-         initialized_pt_vla_path=cfg.initialized_pt_vla_path,
-         feature_vector_weight=cfg.feature_vector_weight,
-         pt_ckpt=cfg.pt_ckpt,
-         task_suite_name=cfg.task_suite_name,
-         use_proprio=False,
-         use_l1_regression=False,
-         use_diffusion=False,
-     )
-
-     pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)
-
-     # Snapshot the parameters to check the change before and after interpolation
-     model_sd = pt_vla.state_dict()
-     before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}
-
-     with torch.no_grad():
-         pt_params = dict(pt_vla.named_parameters())
-         for name, diff in feature_vector_dict.items():
-             if name in pt_params:
-                 pt_param = pt_params[name]
-                 diff = diff.to(pt_param.device)
-                 pt_param.add_(diff, alpha=cfg.feature_vector_weight)
-
-     # Check after interpolation
-     diffs_after = []
-     for name, before_tensor in before_interp_sd.items():
-         after_tensor = model_sd[name]
-         difference = (after_tensor - before_tensor).float().norm().item()
-         diffs_after.append(difference)
-
-     print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
-           f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")
-
-     #########################################################
-     return pt_vla
-
- @draccus.wrap()
- def main(cfg: GenerateConfig):
-     if not os.path.exists(cfg.vector_save_path):
-         feature_vector_dict = generate_feature_vector(cfg)
-         torch.save(feature_vector_dict, cfg.vector_save_path)
-     else:
-         print(f"Feature vector already exists at {cfg.vector_save_path}")
-     initialized_pt_vla = interpolate_feature_vector(cfg)
-     os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
-     initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)
-
- if __name__ == "__main__":
-     main()
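In equation form, `generate_feature_vector` and `interpolate_feature_vector` together implement the merge described in the README, with the scalar weight set by `--feature_vector_weight`:

```latex
v = \theta_{\text{task}} - \theta_{\text{ref}},
\qquad
\theta_{\text{init}} = \theta_{\text{pt}} + \lambda\, v
```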
 
capvector-oft/capvector/interpolate.sh DELETED
@@ -1,26 +0,0 @@
- TASK=spatial  # or object / goal / 10 / 90
- VERSION=21.4
- PT_CKPT="checkpoints/openvla_base"
- TASK_MODEL_CHECKPOINT="checkpoints/task_models/SF_${TASK}"
- REFERENCE_MODEL_CHECKPOINT="checkpoints/reference_models/openvla_oft_libero_${TASK}"
- VECTOR_SAVE_PATH="checkpoints/feature_vectors/feature_vector_with_SF_${TASK}_v${VERSION}.pth"
- INITIALIZED_PT_VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_${TASK}_v${VERSION}"
- TASK_SUITE_NAME="libero_${TASK}"
-
- python interpolate.py \
-   --pretrained_checkpoint "$TASK_MODEL_CHECKPOINT" \
-   --original_pretrained_checkpoint "$REFERENCE_MODEL_CHECKPOINT" \
-   --vector_save_path "$VECTOR_SAVE_PATH" \
-   --initialized_pt_vla_path $INITIALIZED_PT_VLA_PATH \
-   --pt_ckpt $PT_CKPT \
-   --feature_vector_weight 0.5 \
-   --task_suite_name $TASK_SUITE_NAME
-
- # The rsync below copies everything except the VLA backbone weights (e.g., processor
- # and tokenizer files) from the base checkpoint, so that the initialized model is complete.
-
- rsync -av \
-   --ignore-existing \
-   --exclude='*.safetensors' \
-   --exclude='*.back.*' \
-   $PT_CKPT/ \
-   $INITIALIZED_PT_VLA_PATH/
 
capvector-oft/capvector/interpolate_robotwin.py DELETED
@@ -1,247 +0,0 @@
1
- """
2
- This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model.
3
- """
4
-
5
-
6
- import os
7
- import json
8
- import logging
9
-
10
- import sys
11
- from collections import deque
12
- from dataclasses import dataclass
13
- from enum import Enum
14
- from pathlib import Path
15
- from typing import Optional, Union
16
- from PIL import Image
17
-
18
- import draccus
19
- import numpy as np
20
- from tqdm import tqdm
21
- import torch
22
- import copy
23
-
24
- import wandb
25
-
26
- REPO_ROOT = Path(__file__).resolve().parents[1]
27
- if str(REPO_ROOT) not in sys.path:
28
- sys.path.append(str(REPO_ROOT))
29
- from experiments.robot.openvla_utils import (
30
- get_action_head,
31
- get_noisy_action_projector,
32
- get_processor,
33
- get_proprio_projector,
34
- resize_image_for_policy,
35
- )
36
- from experiments.robot.robot_utils import (
37
- DATE_TIME,
38
- get_action,
39
- get_image_resize_size,
40
- get_model,
41
- invert_gripper_action,
42
- normalize_gripper_action,
43
- set_seed_everywhere,
44
- )
45
- from experiments.robot.libero.run_libero_eval import check_unnorm_key
46
- from prismatic.vla.constants import NUM_ACTIONS_CHUNK
47
- from prismatic.vla.constants import PROPRIO_DIM
48
-
49
- # Set up logging
50
- logging.basicConfig(
51
- level=logging.INFO,
52
- format="%(asctime)s [%(levelname)s] %(message)s",
53
- handlers=[logging.StreamHandler()],
54
- )
55
- logger = logging.getLogger(__name__)
56
-
57
-
58
- @dataclass
59
- class GenerateConfig:
60
- # fmt: off
61
-
62
- #################################################################################################################
63
- # Model-specific parameters
64
- #################################################################################################################
65
- model_family: str = "openvla" # Model family
66
- #the task-specific model after sf fine-tuning
67
- pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model" # Task-specific checkpoint path
68
- #the task-specific model after oft fine-tuning
69
- original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model" # Reference checkpoint path
70
- #feature vector is the difference between the two models, which represents the spatial features
71
- vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
72
- #the pt vla model initialized with the feature vector, named rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero}
73
- initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
74
- #the original pretrained openvla model
75
- pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
76
- #the weight of the feature vector when initializing the pt vla model
77
- feature_vector_weight: float = 1 # Weight of feature vector for interpolation
78
-
79
- use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective
80
- use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM)
81
- num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
82
- num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference
83
- use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
84
- num_images_in_input: int = 3 # Number of images in the VLA input (default: 1)
85
- use_proprio: bool = True # Whether to include proprio state in input
86
-
87
- center_crop: bool = True # Center crop? (if trained w/ random crop image aug)
88
- num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy
89
-
90
- lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
91
-
92
- unnorm_key: Union[str, Path] = "" # Action un-normalization key
93
-
94
- load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization
95
- load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization
96
-
97
- #################################################################################################################
98
- # LIBERO environment-specific parameters
99
- #################################################################################################################
100
- task_suite_name: str = "de" # Task suite
101
- num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
102
- num_trials_per_task: int = 50 # Number of rollouts per task
103
- initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file
104
- env_img_res: int = 256 # Resolution for environment images (not policy input resolution)
105
-
106
- #################################################################################################################
107
- # Utils
108
- #################################################################################################################
109
- run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
110
-     local_log_dir: str = "./experiments/logs"        # Local directory for eval logs
-
-     use_wandb: bool = False                           # Whether to also log results in Weights & Biases
-     wandb_entity: str = "your-wandb-entity"           # Name of WandB entity
-     wandb_project: str = "your-wandb-project"         # Name of WandB project
-
-     seed: int = 7                                     # Random Seed (for reproducibility)
-
- def validate_config(cfg: GenerateConfig) -> None:
-     """Validate configuration parameters."""
-     assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"
-
-     if "image_aug" in str(cfg.pretrained_checkpoint):
-         assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
-
-     assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"
-
-     # Validate task suite
-     # assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
-
- def initialize_model(cfg: GenerateConfig, only_pt: bool = False):  # Load action_head and noisy_action_projector separately
-     """Initialize model and associated components."""
-     # Load model
-     model = get_model(cfg)
-
-     # Load proprio projector if needed
-     proprio_projector = None
-     if cfg.use_proprio:
-         proprio_projector = get_proprio_projector(
-             cfg,
-             model.llm_dim,
-             proprio_dim=PROPRIO_DIM,  # Set the proprio dim according to the target robot
-         )
-
-     # Load action head if needed
-     action_head = None
-     if cfg.use_l1_regression or cfg.use_diffusion:
-         action_head = get_action_head(cfg, model.llm_dim)
-
-     # Load noisy action projector if using diffusion
-     noisy_action_projector = None
-     if cfg.use_diffusion:
-         noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)
-
-     # Get OpenVLA processor if needed
-     processor = None
-     if not only_pt:
-         if cfg.model_family == "openvla":
-             processor = get_processor(cfg)
-             # check_unnorm_key(cfg, model)
-
-     return model, action_head, proprio_projector, noisy_action_projector, processor
-
- # @draccus.wrap()
- def generate_feature_vector(cfg: GenerateConfig):
-     """Generate a feature vector (parameter differences) between two task-specific models."""
-     # Validate configuration
-
-     # Set random seed
-     set_seed_everywhere(cfg.seed)
-
-     # Initialize model and components
-     model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
-
-     original_config = GenerateConfig(
-         pretrained_checkpoint=cfg.original_pretrained_checkpoint,
-         task_suite_name=cfg.task_suite_name,
-     )
-
-     original_model, original_action_head, original_proprio_projector, original_noisy_action_projector, original_processor = initialize_model(original_config)
-     # Note: the action_head and noisy_action_projector modules are not interpolated
-     assert len(model.state_dict()) == len(original_model.state_dict())
-     feature_vector_dict = {}
-     total = len(original_model.state_dict())
-     for name, original_model_param in tqdm(original_model.named_parameters(), total=total):
-         model_param = model.state_dict()[name]
-         feature_vector_dict[name] = (model_param - original_model_param).detach().cpu()
-
-     return feature_vector_dict
-
- # @draccus.wrap()
- def interpolate_feature_vector(cfg: GenerateConfig):
-     """Interpolate the feature vector into a pretrained VLA."""
-     feature_vector_dict = torch.load(cfg.vector_save_path)
-
-     pt_vla_config = GenerateConfig(
-         pretrained_checkpoint=cfg.pt_ckpt,
-         original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
-         vector_save_path=cfg.vector_save_path,
-         initialized_pt_vla_path=cfg.initialized_pt_vla_path,
-         feature_vector_weight=cfg.feature_vector_weight,
-         pt_ckpt=cfg.pt_ckpt,
-         task_suite_name=cfg.task_suite_name,
-         use_proprio=False,
-         use_l1_regression=False,
-         use_diffusion=False,
-     )
-
-     pt_vla, _, _, _, _ = initialize_model(pt_vla_config, only_pt=True)
-
-     # Copy the SF parameters so we can check how much they change during interpolation
-     model_sd = pt_vla.state_dict()
-     before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}
-
-     with torch.no_grad():
-         pt_params = dict(pt_vla.named_parameters())
-         for name, diff in feature_vector_dict.items():
-             if name in pt_params:
-                 pt_param = pt_params[name]
-                 diff = diff.to(pt_param.device)
-                 pt_param.add_(diff, alpha=cfg.feature_vector_weight)
-
-     # Check parameter changes after interpolation
-     diffs_after = []
-     for name, before_tensor in before_interp_sd.items():
-         after_tensor = model_sd[name]
-         difference = (after_tensor - before_tensor).float().norm().item()
-         diffs_after.append(difference)
-
-     print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
-           f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")
-
-     #########################################################
-     return pt_vla
-
- @draccus.wrap()
- def main(cfg: GenerateConfig):
-     if not os.path.exists(cfg.vector_save_path):
-         feature_vector_dict = generate_feature_vector(cfg)
-         torch.save(feature_vector_dict, cfg.vector_save_path)
-     else:
-         print(f"Feature vector already exists at {cfg.vector_save_path}")
-     initialized_pt_vla = interpolate_feature_vector(cfg)
-     os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
-     initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)
-
- if __name__ == "__main__":
-     main()
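
The removed script above builds a capability vector as a parameter-wise difference between a task-adapted checkpoint and its base, then adds that vector back into a pretrained VLA with a scalar weight. A minimal sketch of the arithmetic, assuming plain state dicts with matching keys (the helper names are illustrative, not from the repo):

    import torch

    def make_vector(task_sd, base_sd):
        # delta = theta_task - theta_base, kept on CPU like feature_vector_dict above
        return {k: (task_sd[k] - base_sd[k]).detach().cpu() for k in base_sd}

    def apply_vector(target_sd, vector, weight=1.0):
        # theta_new = theta_target + weight * delta; keys absent from the vector pass through
        return {k: (v + weight * vector[k].to(v.device)) if k in vector else v
                for k, v in target_sd.items()}

With weight = 1.0 this reproduces the task checkpoint on shared parameters; intermediate weights interpolate between the pretrained and task-adapted models.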
 
capvector-oft/capvector/tools/check_model_config.py DELETED
@@ -1,23 +0,0 @@
- # This is for checking the completeness of the model parameters.
- import argparse
-
- import torch
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)")
-     args = parser.parse_args()
-
-     fv = torch.load(args.checkpoint_path, map_location="cpu")
-
-     print("num_tensors:", len(fv))
-     nz = 0
-     for _, value in fv.items():
-         if value.abs().sum().item() != 0:
-             nz += 1
-     print("nonzero_tensors:", nz)
-
-
- if __name__ == "__main__":
-     main()
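
One caveat about the count above: the exact `!= 0` comparison treats any floating-point residue as a non-zero tensor. If a tolerance is preferred, a one-line variant (the threshold here is an illustrative choice, not from the repo):

    nz = sum(1 for value in fv.values() if value.abs().sum().item() > 1e-12)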
 
capvector-oft/capvector/tools/compute_lora_diff.py DELETED
@@ -1,36 +0,0 @@
- # This is for computing the difference between the base model and the target model.
- from safetensors.torch import load_file, save_file
- import torch
- import argparse
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--base", required=True)
-     parser.add_argument("--target", required=True)
-     parser.add_argument("--out", default="lora_diff.safetensors")
-     args = parser.parse_args()
-
-     base = load_file(args.base)
-     target = load_file(args.target)
-
-     diff = {}
-
-     print("=== Key Comparison ===")
-     only_in_base = set(base) - set(target)
-     only_in_target = set(target) - set(base)
-
-     print("Only in base:", list(only_in_base)[:10])
-     print("Only in target:", list(only_in_target)[:10])
-
-     for k in target:
-         if k in base:
-             diff[k] = target[k] - base[k]
-         else:
-             # Keep the new parameters
-             diff[k] = target[k].clone()
-
-     save_file(diff, args.out)
-     print(f"\nSaved diff to: {args.out}")
-
- if __name__ == "__main__":
-     main()
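
The saved diff can later be added back onto another base adapter to transfer the update, which is the inverse of the subtraction above. A minimal sketch under that assumption (the function name, `alpha` scaling, and paths are illustrative, not part of the repo):

    from safetensors.torch import load_file, save_file

    def apply_lora_diff(base_path, diff_path, out_path, alpha=1.0):
        base = load_file(base_path)
        diff = load_file(diff_path)
        merged = {}
        for k, d in diff.items():
            # Shared keys: base + alpha * diff; keys only in the diff are copied as-is
            merged[k] = base[k] + alpha * d if k in base else d.clone()
        # Keys only in the base pass through unchanged
        for k, v in base.items():
            merged.setdefault(k, v)
        save_file(merged, out_path)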
 
capvector-oft/capvector/tools/compute_lora_diff.sh DELETED
@@ -1,8 +0,0 @@
- BASE_ADAPTER="checkpoints/reference_models/openvla_oft_libero_spatial/lora_adapter/adapter_model.safetensors"
- TARGET_ADAPTER="checkpoints/task_models/SF_spatial/lora_adapter/adapter_model.safetensors"
- OUTPUT_DIFF="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors"
-
- python compute_lora_diff.py \
-     --base "$BASE_ADAPTER" \
-     --target "$TARGET_ADAPTER" \
-     --out "$OUTPUT_DIFF"
 
capvector-oft/capvector/tools/vector_analyze.py DELETED
@@ -1,153 +0,0 @@
- # This is for analyzing a feature vector of the model and finding out which layers have the largest absolute values.
- import argparse
- import csv
- import os
- import re
- from collections import OrderedDict, defaultdict
-
- import matplotlib.pyplot as plt
- import torch
-
-
- LAYER_PREFIX = "language_model.model.layers."
- NUM_LAYERS = 32
- USE_LOG_Y = True
-
-
- def pick_state_dict(obj):
-     if isinstance(obj, (OrderedDict, dict)):
-         for key in ["state_dict", "model_state_dict", "model", "net", "weights", "params"]:
-             if key in obj and isinstance(obj[key], (OrderedDict, dict)):
-                 return obj[key]
-         if any(torch.is_tensor(value) for value in obj.values()):
-             return obj
-     return None
-
-
- def aggregate_layers_abs_sum(state_dict):
-     layer_sum = defaultdict(float)
-     layer_cnt = defaultdict(int)
-     pattern = re.compile(r"^" + re.escape(LAYER_PREFIX) + r"(\d+)\.")
-
-     for name, tensor in state_dict.items():
-         if not isinstance(name, str):
-             continue
-         match = pattern.match(name)
-         if match is None or not torch.is_tensor(tensor):
-             continue
-
-         layer_id = int(match.group(1))
-         if layer_id < 0 or layer_id >= NUM_LAYERS:
-             continue
-
-         value = tensor.detach()
-         if value.is_cuda:
-             value = value.cpu()
-
-         value = value.to(torch.float64)
-         layer_sum[layer_id] += value.abs().sum().item()
-         layer_cnt[layer_id] += 1
-
-     for layer_id in range(NUM_LAYERS):
-         layer_sum[layer_id] = float(layer_sum.get(layer_id, 0.0))
-         layer_cnt[layer_id] = int(layer_cnt.get(layer_id, 0))
-
-     return layer_sum, layer_cnt
-
-
- def save_layer_csv(layer_sum, layer_cnt, path):
-     output_dir = os.path.dirname(path)
-     if output_dir:
-         os.makedirs(output_dir, exist_ok=True)
-     with open(path, "w", newline="") as file_obj:
-         writer = csv.DictWriter(file_obj, fieldnames=["layer_id", "abs_sum", "num_tensors"])
-         writer.writeheader()
-         for layer_id in range(NUM_LAYERS):
-             writer.writerow(
-                 {
-                     "layer_id": layer_id,
-                     "abs_sum": f"{layer_sum[layer_id]:.12e}",
-                     "num_tensors": layer_cnt[layer_id],
-                 }
-             )
-
-
- def plot_line(xs, ys, out_png, title):
-     ys_plot = ys[:]
-     if USE_LOG_Y:
-         min_pos = min([value for value in ys_plot if value > 0], default=1e-300)
-         eps = min_pos * 1e-6 if min_pos > 0 else 1e-300
-         ys_plot = [value if value > 0 else eps for value in ys_plot]
-
-     plt.figure(figsize=(12, 4.5))
-     plt.plot(xs, ys_plot, marker="o", linewidth=1.5)
-     plt.xlabel("Layer id")
-     plt.ylabel("abs_sum (all params in layer)")
-     plt.title(title)
-     plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5)
-     if USE_LOG_Y:
-         plt.yscale("log")
-     plt.tight_layout()
-     plt.savefig(out_png, dpi=200)
-     plt.close()
-
-
- def plot_bar(xs, ys, out_png, title):
-     ys_plot = ys[:]
-     if USE_LOG_Y:
-         min_pos = min([value for value in ys_plot if value > 0], default=1e-300)
-         eps = min_pos * 1e-6 if min_pos > 0 else 1e-300
-         ys_plot = [value if value > 0 else eps for value in ys_plot]
-
-     plt.figure(figsize=(12, 4.5))
-     plt.bar(xs, ys_plot)
-     plt.xlabel("Layer id")
-     plt.ylabel("abs_sum (all params in layer)")
-     plt.title(title)
-     plt.grid(True, which="both", axis="y", linestyle="--", linewidth=0.5, alpha=0.5)
-     if USE_LOG_Y:
-         plt.yscale("log")
-     plt.tight_layout()
-     plt.savefig(out_png, dpi=200)
-     plt.close()
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)")
-     args = parser.parse_args()
-
-     base = os.path.splitext(args.checkpoint_path)[0]
-     out_csv = base + "_language_model_layers_abs_sum.csv"
-     out_png_line = base + "_language_model_layers_abs_sum_line.png"
-     out_png_bar = base + "_language_model_layers_abs_sum_bar.png"
-
-     ckpt = torch.load(args.checkpoint_path, map_location="cpu")
-     state_dict = pick_state_dict(ckpt)
-
-     if state_dict is None:
-         print("Not a state_dict-like dict. Type:", type(ckpt))
-         if isinstance(ckpt, dict):
-             print("Top-level keys:", list(ckpt.keys())[:50])
-         raise SystemExit(1)
-
-     layer_sum, layer_cnt = aggregate_layers_abs_sum(state_dict)
-     save_layer_csv(layer_sum, layer_cnt, out_csv)
-     print(f"Saved CSV: {out_csv}")
-
-     xs = list(range(NUM_LAYERS))
-     ys = [layer_sum[i] for i in xs]
-     plot_line(xs, ys, out_png_line, f"{LAYER_PREFIX}*: abs_sum per layer")
-     plot_bar(xs, ys, out_png_bar, f"{LAYER_PREFIX}*: abs_sum per layer")
-
-     print(f"Saved plot: {out_png_line}")
-     print(f"Saved plot: {out_png_bar}")
-
-     top = sorted(((i, layer_sum[i], layer_cnt[i]) for i in xs), key=lambda item: item[1], reverse=True)[:5]
-     print("Top-5 layers by abs_sum:")
-     for layer_id, abs_sum, tensor_count in top:
-         print(f"  layer {layer_id:02d}: abs_sum={abs_sum:.6e}, tensors={tensor_count}")
-
-
- if __name__ == "__main__":
-     main()
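
For intuition on the aggregation above: the regex maps each parameter name to its transformer layer id, so every tensor in a layer contributes to that layer's abs_sum. A toy check, assuming Llama-style key names:

    import re

    pattern = re.compile(r"^language_model\.model\.layers\.(\d+)\.")
    match = pattern.match("language_model.model.layers.12.self_attn.q_proj.weight")
    print(int(match.group(1)))  # -> 12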
 
capvector-oft/capvector/tools/vector_regularize.py DELETED
@@ -1,75 +0,0 @@
- # Used to regularize feature vectors by first computing the absolute-sum of each parameter and then performing normalization
-
- import argparse
- from collections import OrderedDict
-
- import torch
-
-
- def pick_state_dict(obj):
-     """Extract state_dict from a checkpoint-like object"""
-     if isinstance(obj, (OrderedDict, dict)):
-         for k in ["state_dict", "model_state_dict", "model", "net", "weights", "params"]:
-             if k in obj and isinstance(obj[k], (OrderedDict, dict)):
-                 return obj[k]
-         if any(torch.is_tensor(v) for v in obj.values()):
-             return obj
-     return None
-
-
- def calculate_total_abs_sum(state_dict):
-     """Compute the sum of absolute values over all parameters"""
-     total_sum = 0.0
-     param_count = 0
-
-     for name, tensor in state_dict.items():
-         if not torch.is_tensor(tensor):
-             continue
-
-         x = tensor.detach()
-         if x.is_cuda:
-             x = x.cpu()
-
-         # Use float64 to ensure numerical precision
-         x = x.to(torch.float64)
-         abs_sum = x.abs().sum().item()
-         total_sum += abs_sum
-         param_count += 1
-
-         print(f"{name}: {abs_sum:.12e} (shape: {list(x.shape)}, numel: {x.numel()})")
-
-     return total_sum, param_count
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)")
-     args = parser.parse_args()
-
-     print(f"Loading checkpoint: {args.checkpoint_path}")
-     ckpt = torch.load(args.checkpoint_path, map_location="cpu")
-     sd = pick_state_dict(ckpt)
-
-     if sd is None:
-         print("Error: failed to extract state_dict from checkpoint")
-         print(f"Checkpoint type: {type(ckpt)}")
-         if isinstance(ckpt, dict):
-             print(f"Top-level keys: {list(ckpt.keys())[:20]}")
-         raise SystemExit(1)
-
-     print(f"\nFound {len(sd)} parameters\n")
-     print("=" * 80)
-     print("Absolute-sum of each parameter:")
-     print("=" * 80)
-
-     total_abs_sum, param_count = calculate_total_abs_sum(sd)
-
-     print("=" * 80)
-     print("\nSummary:")
-     print(f"  Total number of parameters: {param_count}")
-     print(f"  Sum of absolute values of all parameters: {total_abs_sum:.12e}")
-     print(f"  Sum of absolute values of all parameters (scientific notation): {total_abs_sum:.6e}")
-
-
- if __name__ == "__main__":
-     main()
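
The header comment mentions normalization, but the script above only reports the absolute sums. A hedged sketch of what that step could look like, assuming the goal is to rescale a feature vector so its total L1 mass matches a target value (the function name and target are illustrative, not from the repo):

    import torch

    def normalize_vector(vector, target_abs_sum):
        # Rescale every tensor by target / total so the vector's total abs-sum equals the target
        total = sum(t.detach().to(torch.float64).abs().sum().item()
                    for t in vector.values() if torch.is_tensor(t))
        scale = target_abs_sum / total
        return {k: (t * scale if torch.is_tensor(t) else t) for k, t in vector.items()}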
 
capvector-oft/experiments/robot/aloha/aloha_utils.py DELETED
@@ -1,85 +0,0 @@
- """Utils for evaluating policies in real-world ALOHA environments."""
-
- import os
-
- import imageio
- import numpy as np
- from PIL import Image
-
- from experiments.robot.aloha.real_env import make_real_env
- from experiments.robot.robot_utils import (
-     DATE,
-     DATE_TIME,
- )
-
-
- def get_next_task_label(task_label):
-     """Prompt the user to input the next task."""
-     if task_label == "":
-         user_input = ""
-         while user_input == "":
-             user_input = input("Enter the task name: ")
-         task_label = user_input
-     else:
-         user_input = input("Enter the task name (or leave blank to repeat the previous task): ")
-         if user_input == "":
-             pass  # Do nothing -> Let task_label stay the same
-         else:
-             task_label = user_input
-     print(f"Task: {task_label}")
-     return task_label
-
-
- def get_aloha_env():
-     """Initializes and returns the ALOHA environment."""
-     env = make_real_env(init_node=True)
-     return env
-
-
- def resize_image_for_preprocessing(img):
-     """
-     Takes a numpy array corresponding to a single image and resizes it to 256x256, exactly as done
-     in the ALOHA data preprocessing script, which is used before converting the dataset to RLDS.
-     """
-     ALOHA_PREPROCESS_SIZE = 256
-     img = np.array(
-         Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC)
-     )  # BICUBIC is default; specify explicitly to make it clear
-     return img
-
-
- def get_aloha_image(obs):
-     """Extracts third-person image from observations and preprocesses it."""
-     # obs: dm_env._environment.TimeStep
-     img = obs.observation["images"]["cam_high"]
-     img = resize_image_for_preprocessing(img)
-     return img
-
-
- def get_aloha_wrist_images(obs):
-     """Extracts both wrist camera images from observations and preprocesses them."""
-     # obs: dm_env._environment.TimeStep
-     left_wrist_img = obs.observation["images"]["cam_left_wrist"]
-     right_wrist_img = obs.observation["images"]["cam_right_wrist"]
-     left_wrist_img = resize_image_for_preprocessing(left_wrist_img)
-     right_wrist_img = resize_image_for_preprocessing(right_wrist_img)
-     return left_wrist_img, right_wrist_img
-
-
- def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None):
-     """Saves an MP4 replay of an episode."""
-     rollout_dir = f"./rollouts/{DATE}"
-     os.makedirs(rollout_dir, exist_ok=True)
-     processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
-     filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}"
-     if notes is not None:
-         filetag += f"--{notes}"
-     mp4_path = f"{filetag}.mp4"
-     video_writer = imageio.get_writer(mp4_path, fps=25)
-     for img in rollout_images:
-         video_writer.append_data(img)
-     video_writer.close()
-     print(f"Saved rollout MP4 at path {mp4_path}")
-     if log_file is not None:
-         log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
-     return mp4_path
 
capvector-oft/experiments/robot/aloha/constants.py DELETED
@@ -1,100 +0,0 @@
- ### Task parameters
-
- DATA_DIR = '/scr2/moojink/data/aloha1/'
- TASK_CONFIGS = {
-     # fold shorts
-     'fold_shorts': {
-         'dataset_dir': DATA_DIR + '/fold_shorts',
-         'num_episodes': 20,
-         'episode_len': 1000,
-         'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     # fold shirt
-     'fold_shirt': {
-         'dataset_dir': DATA_DIR + '/fold_shirt',
-         'num_episodes': 30,
-         'episode_len': 1250,
-         'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     # scoop X into bowl
-     'scoop_raisins_into_bowl': {
-         'dataset_dir': DATA_DIR + '/scoop_raisins_into_bowl',
-         'num_episodes': 15,
-         'episode_len': 900,
-         'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     'scoop_almonds_and_green_M&Ms_into_bowl': {
-         'dataset_dir': DATA_DIR + '/scoop_almonds_and_green_M&Ms_into_bowl',
-         'num_episodes': 15,
-         'episode_len': 900,
-         'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     'scoop_pretzels_into_bowl': {
-         'dataset_dir': DATA_DIR + '/scoop_pretzels_into_bowl',
-         'num_episodes': 15,
-         'episode_len': 900,
-         'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     # put X into pot
-     'put_red_pepper_into_pot': {
-         'dataset_dir': DATA_DIR + '/put_red_pepper_into_pot',
-         'num_episodes': 100,
-         'episode_len': 400,
-         'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     'put_yellow_corn_into_pot': {
-         'dataset_dir': DATA_DIR + '/put_yellow_corn_into_pot',
-         'num_episodes': 100,
-         'episode_len': 400,
-         'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
-     },
-     'put_green_pepper_into_pot': {
-         'dataset_dir': DATA_DIR + '/put_green_pepper_into_pot',
-         'num_episodes': 100,
-         'episode_len': 400,
-         'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
-     },
- }
-
- ### ALOHA fixed constants
- DT = 0.04  # 1 / 0.04 -> 25 Hz
- JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
- START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
-
- # Left finger position limits (qpos[7]), right_finger = -1 * left_finger
- MASTER_GRIPPER_POSITION_OPEN = 0.02417
- MASTER_GRIPPER_POSITION_CLOSE = 0.01244
- PUPPET_GRIPPER_POSITION_OPEN = 0.05800
- PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
-
- # Gripper joint limits (qpos[6])
- MASTER_GRIPPER_JOINT_OPEN = 0.3083    # For ALOHA 1
- MASTER_GRIPPER_JOINT_CLOSE = -0.6842  # For ALOHA 1
- # MASTER_GRIPPER_JOINT_OPEN = -0.8    # For ALOHA 2
- # MASTER_GRIPPER_JOINT_CLOSE = -1.65  # For ALOHA 2
- PUPPET_GRIPPER_JOINT_OPEN = 1.4910
- PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
-
- ############################ Helper functions ############################
-
- MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
- PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
- MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
- PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
- MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
-
- MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
- PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
- MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
- PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
- MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
-
- MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
- PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
-
- MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
- MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
- PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
- PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
-
- MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
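
As a worked example of the conversion lambdas above: MASTER2PUPPET_JOINT_FN first normalizes a master joint reading into [0, 1], then unnormalizes it into the puppet's joint range. Plugging in the ALOHA 1 constants at the master's midpoint:

    x = MASTER_GRIPPER_JOINT_MID                   # (0.3083 + -0.6842) / 2 = -0.18795 rad
    x_norm = MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)  # 0.5
    x_puppet = MASTER2PUPPET_JOINT_FN(x)           # 0.5 * (1.4910 - -0.6213) + -0.6213 ≈ 0.4349 rad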
 
capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py DELETED
@@ -1,260 +0,0 @@
- """
- Preprocesses ALOHA dataset(s) and splits them into train/val sets.
-
- Preprocessing includes downsizing images from 480x640 to 256x256.
- Splits happen at the episode level (not step level), which means that
- an episode is treated as an atomic unit that entirely goes to either
- the train set or the val set.
-
- Original ALOHA data layout:
-     /PATH/TO/DATASET/dataset_name/
-         - episode_0.hdf5
-         - episode_1.hdf5
-         - ...
-         - episode_N.hdf5
-
- Preprocessed data layout (after running this script):
-     /PATH/TO/PREPROCESSED_DATASETS/dataset_name/
-         - train/
-             - episode_0.hdf5
-             - episode_1.hdf5
-             - ...
-             - episode_M.hdf5
-         - val/
-             - episode_0.hdf5
-             - episode_1.hdf5
-             - ...
-             - episode_K.hdf5
-
- where N > M > K
-
- Example usage:
-     # "put X into pot" task
-     python experiments/robot/aloha/preprocess_split_aloha_data.py \
-         --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \
-         --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
-         --percent_val 0.05 && \
-     python experiments/robot/aloha/preprocess_split_aloha_data.py \
-         --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \
-         --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
-         --percent_val 0.05 && \
-     python experiments/robot/aloha/preprocess_split_aloha_data.py \
-         --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \
-         --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
-         --percent_val 0.05
- """
-
- import argparse
- import glob
- import os
- import random
-
- import h5py
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
-
- def load_hdf5(demo_path):
-     """Loads a single episode."""
-     if not os.path.isfile(demo_path):
-         print(f"Dataset does not exist at \n{demo_path}\n")
-         exit()
-
-     print(f"Loading {demo_path}...")
-     with h5py.File(demo_path, "r") as root:
-         is_sim = root.attrs["sim"]
-         qpos = root["/observations/qpos"][()]
-         qvel = root["/observations/qvel"][()]
-         effort = root["/observations/effort"][()]
-         action = root["/action"][()]
-         image_dict = dict()
-         for cam_name in root["/observations/images/"].keys():
-             image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()]
-     print(f"Loading episode complete: {demo_path}")
-
-     return qpos, qvel, effort, action, image_dict, is_sim
-
-
- def load_and_preprocess_all_episodes(demo_paths, out_dataset_dir):
-     """
-     Loads and preprocesses all episodes.
-     Resizes all images in one episode before loading the next, to reduce memory usage.
-     """
-     cam_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
-     idx = 0
-     for demo in tqdm(demo_paths):
-         qpos, qvel, effort, action, image_dict, is_sim = load_hdf5(demo)
-         # Save non-image info
-         episode_len = image_dict["cam_high"].shape[0]
-         # Resize all images
-         print("Resizing images in episode...")
-         for k in cam_names:
-             resized_images = []
-             for i in range(episode_len):
-                 resized_images.append(
-                     np.array(
-                         Image.fromarray(image_dict[k][i]).resize(
-                             (args.img_resize_size, args.img_resize_size), resample=Image.BICUBIC
-                         )
-                     )  # BICUBIC is default; specify explicitly to make it clear
-                 )
-             image_dict[k] = np.stack(resized_images)
-         print("Resizing images in episode complete!")
-         # Save preprocessed episode
-         data_dict = dict(
-             qpos=qpos,
-             qvel=qvel,
-             effort=effort,
-             action=action,
-             image_dict=image_dict,
-             is_sim=is_sim,
-         )
-         save_new_hdf5(out_dataset_dir, data_dict, idx)
-         idx += 1
-
-
- def randomly_split(full_qpos, full_qvel, full_effort, full_action, full_image_dict, percent_val):
-     """Randomly splits the dataset into train and validation sets."""
-     # Create a list of episode indices
-     num_episodes_total = len(full_qpos)
-     indices = list(range(num_episodes_total))
-     # Shuffle the episode indices
-     random.shuffle(indices)
-     # Create new lists using the shuffled indices
-     shuffled_qpos = [full_qpos[idx] for idx in indices]
-     shuffled_qvel = [full_qvel[idx] for idx in indices]
-     shuffled_effort = [full_effort[idx] for idx in indices]
-     shuffled_action = [full_action[idx] for idx in indices]
-     shuffled_image_dict = {
-         "cam_high": [],
-         "cam_left_wrist": [],
-         "cam_right_wrist": [],
-     }
-     for k in full_image_dict.keys():
-         shuffled_image_dict[k] = [full_image_dict[k][idx] for idx in indices]
-     # Split into train and val sets
-     num_episodes_val = int(num_episodes_total * percent_val)
-     print(f"Total # episodes: {num_episodes_total}; using {num_episodes_val} ({percent_val:.2%}) for val set")
-     num_episodes_train = num_episodes_total - num_episodes_val
-     train_dict = dict(
-         qpos=shuffled_qpos[:num_episodes_train],
-         qvel=shuffled_qvel[:num_episodes_train],
-         effort=shuffled_effort[:num_episodes_train],
-         action=shuffled_action[:num_episodes_train],
-         image_dict=dict(
-             cam_high=shuffled_image_dict["cam_high"][:num_episodes_train],
-             cam_left_wrist=shuffled_image_dict["cam_left_wrist"][:num_episodes_train],
-             cam_right_wrist=shuffled_image_dict["cam_right_wrist"][:num_episodes_train],
-         ),
-     )
-     val_dict = dict(
-         qpos=shuffled_qpos[num_episodes_train:],
-         qvel=shuffled_qvel[num_episodes_train:],
-         effort=shuffled_effort[num_episodes_train:],
-         action=shuffled_action[num_episodes_train:],
-         image_dict=dict(
-             cam_high=shuffled_image_dict["cam_high"][num_episodes_train:],
-             cam_left_wrist=shuffled_image_dict["cam_left_wrist"][num_episodes_train:],
-             cam_right_wrist=shuffled_image_dict["cam_right_wrist"][num_episodes_train:],
-         ),
-     )
-     return train_dict, val_dict
-
-
- def save_new_hdf5(out_dataset_dir, data_dict, episode_idx):
-     """Saves an HDF5 file for a new episode."""
-     camera_names = data_dict["image_dict"].keys()
-     H, W, C = data_dict["image_dict"]["cam_high"][0].shape
-     out_path = os.path.join(out_dataset_dir, f"episode_{episode_idx}.hdf5")
-     # Save HDF5 with the same structure as the original demos
-     with h5py.File(
-         out_path, "w", rdcc_nbytes=1024**2 * 2
-     ) as root:  # Magic constant for rdcc_nbytes comes from the ALOHA codebase
-         episode_len = data_dict["qpos"].shape[0]
-         root.attrs["sim"] = data_dict["is_sim"]
-         obs = root.create_group("observations")
-         _ = obs.create_dataset("qpos", (episode_len, 14))
-         _ = obs.create_dataset("qvel", (episode_len, 14))
-         _ = obs.create_dataset("effort", (episode_len, 14))
-         root["/observations/qpos"][...] = data_dict["qpos"]
-         root["/observations/qvel"][...] = data_dict["qvel"]
-         root["/observations/effort"][...] = data_dict["effort"]
-         image = obs.create_group("images")
-         for cam_name in camera_names:
-             _ = image.create_dataset(
-                 cam_name,
-                 (episode_len, H, W, C),
-                 dtype="uint8",
-                 chunks=(1, H, W, C),
-             )
-             root[f"/observations/images/{cam_name}"][...] = data_dict["image_dict"][cam_name]
-         _ = root.create_dataset("action", (episode_len, 14))
-         root["/action"][...] = data_dict["action"]
-         # Compute and save *relative* actions as well
-         actions = data_dict["action"]
-         relative_actions = np.zeros_like(actions)
-         relative_actions[:-1] = actions[1:] - actions[:-1]  # Relative actions are the changes in joint pos
-         relative_actions[-1] = relative_actions[-2]  # Just copy the second-to-last action for the last action
-         _ = root.create_dataset("relative_action", (episode_len, 14))
-         root["/relative_action"][...] = relative_actions
-     print(f"Saved dataset: {out_path}")
-
-
- def main(args):
-     # Create directory to save preprocessed dataset (if it doesn't exist already)
-     os.makedirs(args.out_base_dir, exist_ok=True)
-     out_dataset_dir = os.path.join(args.out_base_dir, os.path.basename(args.dataset_path.rstrip("/")))
-     os.makedirs(out_dataset_dir, exist_ok=True)
-     # Get list of filepaths of all episodes
-     all_demo_paths = glob.glob(os.path.join(args.dataset_path, "*.hdf5"))  # List of HDF5 filepaths
-     all_demo_paths.sort()
-     # Create a list of episode indices
-     num_episodes_total = len(all_demo_paths)
-     indices = list(range(num_episodes_total))
-     # Shuffle the episode indices
-     random.shuffle(indices)
-     # Split into train and val sets
-     num_episodes_val = int(num_episodes_total * args.percent_val)
-     print(f"Total # episodes: {num_episodes_total}; using {num_episodes_val} ({args.percent_val:.2%}) for val set")
-     num_episodes_train = num_episodes_total - num_episodes_val
-     train_indices = indices[:num_episodes_train]
-     val_indices = indices[num_episodes_train:]
-     train_demo_paths = [all_demo_paths[i] for i in train_indices]
-     val_demo_paths = [all_demo_paths[i] for i in val_indices]
-     # Preprocess all episodes and save the result
-     out_dataset_dir_train = os.path.join(out_dataset_dir, "train")
-     out_dataset_dir_val = os.path.join(out_dataset_dir, "val")
-     os.makedirs(out_dataset_dir_train, exist_ok=True)
-     os.makedirs(out_dataset_dir_val, exist_ok=True)
-     load_and_preprocess_all_episodes(train_demo_paths, out_dataset_dir_train)
-     load_and_preprocess_all_episodes(val_demo_paths, out_dataset_dir_val)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--dataset_path",
-         required=True,
-         help="Path to raw ALOHA dataset directory. Example: /PATH/TO/USER/data/aloha_raw/put_green_pepper_into_pot/",
-     )
-     parser.add_argument(
-         "--out_base_dir",
-         required=True,
-         help="Path to directory in which to save preprocessed dataset. Example: /PATH/TO/USER/data/aloha_preprocessed/",
-     )
-     parser.add_argument(
-         "--percent_val",
-         type=float,
-         help="Fraction of the dataset to use as the validation set (measured in episodes, not steps).",
-         default=0.05,
-     )
-     parser.add_argument(
-         "--img_resize_size",
-         type=int,
-         help="Size to resize images to. Final images will be square (img_resize_size x img_resize_size pixels).",
-         default=256,
-     )
-     args = parser.parse_args()
-
-     main(args)
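
A small sanity check for the output format, assuming an episode file written by the script above (the path is illustrative): cumulatively applying relative_action should recover the absolute actions for every step except the last, whose relative action is copied from the second-to-last.

    import h5py
    import numpy as np

    with h5py.File("episode_0.hdf5", "r") as f:
        action = f["/action"][()]
        rel = f["/relative_action"][()]

    # action[t + 1] == action[t] + rel[t] for all but the final step
    assert np.allclose(action[1:], action[:-1] + rel[:-1])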
 
capvector-oft/experiments/robot/aloha/real_env.py DELETED
@@ -1,213 +0,0 @@
- import time
- import numpy as np
- import collections
- import matplotlib.pyplot as plt
- import dm_env
-
- from experiments.robot.aloha.constants import DT, START_ARM_POSE, MASTER_GRIPPER_JOINT_NORMALIZE_FN, PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN
- from experiments.robot.aloha.constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
- from experiments.robot.aloha.constants import PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
- from experiments.robot.aloha.robot_utils import Recorder, ImageRecorder
- from experiments.robot.aloha.robot_utils import setup_master_bot, setup_puppet_bot, move_arms, move_grippers
- from interbotix_xs_modules.arm import InterbotixManipulatorXS
- from interbotix_xs_msgs.msg import JointSingleCommand
-
- import IPython
- e = IPython.embed
-
-
- class RealEnv:
-     """
-     Environment for real-robot bimanual manipulation.
-     Action space:      [left_arm_qpos (6),            # absolute joint position
-                         left_gripper_positions (1),   # normalized gripper position (0: close, 1: open)
-                         right_arm_qpos (6),           # absolute joint position
-                         right_gripper_positions (1)]  # normalized gripper position (0: close, 1: open)
-
-     Observation space: {"qpos": Concat[left_arm_qpos (6),          # absolute joint position
-                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
-                                        right_arm_qpos (6),         # absolute joint position
-                                        right_gripper_qpos (1)],    # normalized gripper position (0: close, 1: open)
-                         "qvel": Concat[left_arm_qvel (6),          # absolute joint velocity (rad)
-                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
-                                        right_arm_qvel (6),         # absolute joint velocity (rad)
-                                        right_gripper_qvel (1)],    # normalized gripper velocity (pos: opening, neg: closing)
-                         "images": {"cam_high": (480x640x3),          # h, w, c, dtype='uint8'
-                                    "cam_low": (480x640x3),           # h, w, c, dtype='uint8'
-                                    "cam_left_wrist": (480x640x3),    # h, w, c, dtype='uint8'
-                                    "cam_right_wrist": (480x640x3)}}  # h, w, c, dtype='uint8'
-     """
-
-     def __init__(self, init_node, setup_robots=True):
-         self.puppet_bot_left = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper",
-                                                        robot_name='puppet_left', init_node=init_node)
-         self.puppet_bot_right = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper",
-                                                         robot_name='puppet_right', init_node=False)
-         if setup_robots:
-             self.setup_robots()
-
-         self.recorder_left = Recorder('left', init_node=False)
-         self.recorder_right = Recorder('right', init_node=False)
-         self.image_recorder = ImageRecorder(init_node=False)
-         self.gripper_command = JointSingleCommand(name="gripper")
-
-     def setup_robots(self):
-         setup_puppet_bot(self.puppet_bot_left)
-         setup_puppet_bot(self.puppet_bot_right)
-
-     def get_qpos(self):
-         left_qpos_raw = self.recorder_left.qpos
-         right_qpos_raw = self.recorder_right.qpos
-         left_arm_qpos = left_qpos_raw[:6]
-         right_arm_qpos = right_qpos_raw[:6]
-         left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])]  # this is a position, not a joint angle
-         right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])]  # this is a position, not a joint angle
-         return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
-
-     def get_qvel(self):
-         left_qvel_raw = self.recorder_left.qvel
-         right_qvel_raw = self.recorder_right.qvel
-         left_arm_qvel = left_qvel_raw[:6]
-         right_arm_qvel = right_qvel_raw[:6]
-         left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
-         right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
-         return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
-
-     def get_effort(self):
-         left_effort_raw = self.recorder_left.effort
-         right_effort_raw = self.recorder_right.effort
-         left_robot_effort = left_effort_raw[:7]
-         right_robot_effort = right_effort_raw[:7]
-         return np.concatenate([left_robot_effort, right_robot_effort])
-
-     def get_images(self):
-         return self.image_recorder.get_images()
-
-     def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
-         left_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
-         self.gripper_command.cmd = left_gripper_desired_joint
-         self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
-
-         right_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(right_gripper_desired_pos_normalized)
-         self.gripper_command.cmd = right_gripper_desired_joint
-         self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
-
-     def _reset_joints(self):
-         reset_position = START_ARM_POSE[:6]
-         move_arms([self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1)
-
-     def _reset_gripper(self):
-         """Do position resets: first open both grippers, then close them."""
-         move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)
-         move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1)
-
-     def _get_obs(self):
-         obs = collections.OrderedDict()
-         obs['qpos'] = self.get_qpos()
-         obs['qvel'] = self.get_qvel()
-         obs['effort'] = self.get_effort()
-         obs['images'] = self.get_images()
-         return obs
-
-     def get_observation(self, t=0):
-         step_type = dm_env.StepType.FIRST if t == 0 else dm_env.StepType.MID
-         return dm_env.TimeStep(
-             step_type=step_type,
-             reward=self.get_reward(),
-             discount=None,
-             observation=self._get_obs()
-         )
-
-     def get_reward(self):
-         return 0
-
-     def reset(self, fake=False):
-         if not fake:
-             # Reboot puppet robot gripper motors
-             self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
-             self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
-             self._reset_joints()
-             self._reset_gripper()
-         return dm_env.TimeStep(
-             step_type=dm_env.StepType.FIRST,
-             reward=self.get_reward(),
-             discount=None,
-             observation=self._get_obs())
-
-     def step(self, action):
-         state_len = int(len(action) / 2)
-         left_action = action[:state_len]
-         right_action = action[state_len:]
-         self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
-         self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
-         self.set_gripper_pose(left_action[-1], right_action[-1])
-         time.sleep(DT)
-         return dm_env.TimeStep(
-             step_type=dm_env.StepType.MID,
-             reward=self.get_reward(),
-             discount=None,
-             observation=self._get_obs())
-
-
- def get_action(master_bot_left, master_bot_right):
-     action = np.zeros(14)  # 6 joints + 1 gripper, for two arms
-     # Arm actions
-     action[:6] = master_bot_left.dxl.joint_states.position[:6]
-     action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
-     # Gripper actions
-     action[6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
-     action[7+6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
-
-     return action
-
-
- def make_real_env(init_node, setup_robots=True):
-     env = RealEnv(init_node, setup_robots)
-     return env
-
-
- def test_real_teleop():
-     """
-     Tests bimanual teleoperation and shows image observations onscreen.
-     It first reads joint poses from both master arms, then uses them as actions to step the environment.
-     The environment returns full observations including images.
-
-     An alternative approach is to have separate scripts for teleoperation and observation recording.
-     This script will result in higher-fidelity (obs, action) pairs.
-     """
-
-     onscreen_render = True
-     render_cam = 'cam_left_wrist'
-
-     # Source of data
-     master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
-                                               robot_name='master_left', init_node=True)
-     master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
-                                                robot_name='master_right', init_node=False)
-     setup_master_bot(master_bot_left)
-     setup_master_bot(master_bot_right)
-
-     # Set up the environment
-     env = make_real_env(init_node=False)
-     ts = env.reset(fake=True)
-     episode = [ts]
-     # Set up visualization
-     if onscreen_render:
-         ax = plt.subplot()
-         plt_img = ax.imshow(ts.observation['images'][render_cam])
-         plt.ion()
-
-     for t in range(1000):
-         action = get_action(master_bot_left, master_bot_right)
-         ts = env.step(action)
-         episode.append(ts)
-
-         if onscreen_render:
-             plt_img.set_data(ts.observation['images'][render_cam])
-             plt.pause(DT)
-         else:
-             time.sleep(DT)
-
-
- if __name__ == '__main__':
-     test_real_teleop()
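
For reference, the action documented in the RealEnv docstring is a flat 14-dimensional vector. A tiny sketch constructing one (the joint targets are illustrative placeholders):

    import numpy as np

    left_arm_qpos = np.zeros(6)    # absolute joint targets (rad), left arm
    right_arm_qpos = np.zeros(6)   # absolute joint targets (rad), right arm
    action = np.zeros(14)
    action[0:6] = left_arm_qpos
    action[6] = 1.0                # left gripper, normalized (0: close, 1: open)
    action[7:13] = right_arm_qpos
    action[13] = 0.0               # right gripper, normalized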
 
capvector-oft/experiments/robot/aloha/requirements_aloha.txt DELETED
@@ -1,26 +0,0 @@
- numpy<2
- draccus
- torchvision
- torch
- pyquaternion
- pyyaml
- rospkg
- pexpect
- mujoco==2.3.7
- dm_control==1.0.14
- opencv-python
- matplotlib
- einops
- packaging
- h5py
- traitlets
- ipdb
- IPython
- modern_robotics
- Pillow
- termcolor
- imageio[ffmpeg]
- uvicorn
- fastapi
- requests
- json_numpy
 
capvector-oft/experiments/robot/aloha/robot_utils.py DELETED
@@ -1,187 +0,0 @@
- import numpy as np
- import time
- from experiments.robot.aloha.constants import DT
- from interbotix_xs_msgs.msg import JointSingleCommand
-
- import IPython
- e = IPython.embed
-
-
- class ImageRecorder:
-     def __init__(self, init_node=True, is_debug=False):
-         from collections import deque
-         import rospy
-         from cv_bridge import CvBridge
-         from sensor_msgs.msg import Image
-         self.is_debug = is_debug
-         self.bridge = CvBridge()
-         self.camera_names = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
-         if init_node:
-             rospy.init_node('image_recorder', anonymous=True)
-         for cam_name in self.camera_names:
-             setattr(self, f'{cam_name}_image', None)
-             setattr(self, f'{cam_name}_secs', None)
-             setattr(self, f'{cam_name}_nsecs', None)
-             if cam_name == 'cam_high':
-                 callback_func = self.image_cb_cam_high
-             elif cam_name == 'cam_low':
-                 callback_func = self.image_cb_cam_low
-             elif cam_name == 'cam_left_wrist':
-                 callback_func = self.image_cb_cam_left_wrist
-             elif cam_name == 'cam_right_wrist':
-                 callback_func = self.image_cb_cam_right_wrist
-             else:
-                 raise NotImplementedError
-             rospy.Subscriber(f"/usb_{cam_name}/image_raw", Image, callback_func)
-             if self.is_debug:
-                 setattr(self, f'{cam_name}_timestamps', deque(maxlen=50))
-         time.sleep(0.5)
-
-     def image_cb(self, cam_name, data):
-         setattr(self, f'{cam_name}_image', self.bridge.imgmsg_to_cv2(data, desired_encoding='passthrough'))
-         setattr(self, f'{cam_name}_secs', data.header.stamp.secs)
-         setattr(self, f'{cam_name}_nsecs', data.header.stamp.nsecs)
-         # cv2.imwrite('/home/tonyzhao/Desktop/sample.jpg', cv_image)
-         if self.is_debug:
-             # Convert the stamp to seconds: secs plus nsecs scaled to seconds
-             getattr(self, f'{cam_name}_timestamps').append(data.header.stamp.secs + data.header.stamp.nsecs * 1e-9)
-
-     def image_cb_cam_high(self, data):
-         cam_name = 'cam_high'
-         return self.image_cb(cam_name, data)
-
-     def image_cb_cam_low(self, data):
-         cam_name = 'cam_low'
-         return self.image_cb(cam_name, data)
-
-     def image_cb_cam_left_wrist(self, data):
-         cam_name = 'cam_left_wrist'
-         return self.image_cb(cam_name, data)
-
-     def image_cb_cam_right_wrist(self, data):
-         cam_name = 'cam_right_wrist'
-         return self.image_cb(cam_name, data)
-
-     def get_images(self):
-         image_dict = dict()
-         for cam_name in self.camera_names:
-             image_dict[cam_name] = getattr(self, f'{cam_name}_image')
-         return image_dict
-
-     def print_diagnostics(self):
-         def dt_helper(l):
-             l = np.array(l)
-             diff = l[1:] - l[:-1]
-             return np.mean(diff)
-         for cam_name in self.camera_names:
-             image_freq = 1 / dt_helper(getattr(self, f'{cam_name}_timestamps'))
-             print(f'{cam_name} {image_freq=:.2f}')
-         print()
-
-
- class Recorder:
-     def __init__(self, side, init_node=True, is_debug=False):
-         from collections import deque
-         import rospy
-         from sensor_msgs.msg import JointState
-         from interbotix_xs_msgs.msg import JointGroupCommand, JointSingleCommand
-
-         self.secs = None
-         self.nsecs = None
-         self.qpos = None
-         self.effort = None
-         self.arm_command = None
-         self.gripper_command = None
-         self.is_debug = is_debug
-
-         if init_node:
-             rospy.init_node('recorder', anonymous=True)
-         rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
-         rospy.Subscriber(f"/puppet_{side}/commands/joint_group", JointGroupCommand, self.puppet_arm_commands_cb)
-         rospy.Subscriber(f"/puppet_{side}/commands/joint_single", JointSingleCommand, self.puppet_gripper_commands_cb)
-         if self.is_debug:
-             self.joint_timestamps = deque(maxlen=50)
-             self.arm_command_timestamps = deque(maxlen=50)
-             self.gripper_command_timestamps = deque(maxlen=50)
-         time.sleep(0.1)
-
-     def puppet_state_cb(self, data):
-         self.qpos = data.position
-         self.qvel = data.velocity
-         self.effort = data.effort
-         self.data = data
-         if self.is_debug:
-             self.joint_timestamps.append(time.time())
-
-     def puppet_arm_commands_cb(self, data):
-         self.arm_command = data.cmd
-         if self.is_debug:
-             self.arm_command_timestamps.append(time.time())
-
-     def puppet_gripper_commands_cb(self, data):
-         self.gripper_command = data.cmd
-         if self.is_debug:
-             self.gripper_command_timestamps.append(time.time())
-
-     def print_diagnostics(self):
-         def dt_helper(l):
-             l = np.array(l)
-             diff = l[1:] - l[:-1]
-             return np.mean(diff)
-
-         joint_freq = 1 / dt_helper(self.joint_timestamps)
-         arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
-         gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
-
-         print(f'{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n')
-
-
- def get_arm_joint_positions(bot):
-     return bot.arm.core.joint_states.position[:6]
-
-
- def get_arm_gripper_positions(bot):
-     joint_position = bot.gripper.core.joint_states.position[6]
-     return joint_position
-
-
- def move_arms(bot_list, target_pose_list, move_time=1):
-     num_steps = int(move_time / DT)
-     curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
-     traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)]
-     for t in range(num_steps):
-         for bot_id, bot in enumerate(bot_list):
-             bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
-         time.sleep(DT)
-
-
- def move_grippers(bot_list, target_pose_list, move_time):
-     gripper_command = JointSingleCommand(name="gripper")
-     num_steps = int(move_time / DT)
-     curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
-     traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)]
-     for t in range(num_steps):
-         for bot_id, bot in enumerate(bot_list):
-             gripper_command.cmd = traj_list[bot_id][t]
-             bot.gripper.core.pub_single.publish(gripper_command)
-         time.sleep(DT)
-
-
- def setup_puppet_bot(bot):
-     bot.dxl.robot_reboot_motors("single", "gripper", True)
-     bot.dxl.robot_set_operating_modes("group", "arm", "position")
-     bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
-     torque_on(bot)
-
-
- def setup_master_bot(bot):
-     bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
-     bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
-     torque_off(bot)
-
-
- def set_standard_pid_gains(bot):
-     bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 800)
-     bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0)
-
-
- def set_low_pid_gains(bot):
-     bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 100)
-     bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0)
-
-
- def torque_off(bot):
-     bot.dxl.robot_torque_enable("group", "arm", False)
-     bot.dxl.robot_torque_enable("single", "gripper", False)
-
-
- def torque_on(bot):
-     bot.dxl.robot_torque_enable("group", "arm", True)
-     bot.dxl.robot_torque_enable("single", "gripper", True)
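
move_arms and move_grippers above both discretize a move into num_steps = move_time / DT waypoints along a straight line in joint space, commanding one waypoint per control period. A worked example at the 25 Hz rate (the joint sweep is illustrative):

    import numpy as np

    DT = 0.04                                # 25 Hz control period
    move_time = 1.0
    num_steps = int(move_time / DT)          # 25 waypoints
    traj = np.linspace(0.0, 0.5, num_steps)  # one joint sweeping 0 -> 0.5 rad
    # Each waypoint is commanded, then the loop sleeps DT before the next one.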
 
capvector-oft/experiments/robot/aloha/run_aloha_eval.py DELETED
@@ -1,385 +0,0 @@
- """
- run_aloha_eval.py
-
- Evaluates a model in a real-world ALOHA environment.
- """
-
- import logging
- import os
- import socket
- import sys
- import time
- from collections import deque
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Optional, Union
-
- import draccus
- import tqdm
-
- # Append current directory so that the interpreter can find experiments.robot
- sys.path.append(".")
- from experiments.robot.aloha.aloha_utils import (
-     get_aloha_env,
-     get_aloha_image,
-     get_aloha_wrist_images,
-     get_next_task_label,
-     save_rollout_video,
- )
- from experiments.robot.openvla_utils import (
-     get_action_from_server,
-     resize_image_for_policy,
- )
- from experiments.robot.robot_utils import (
-     DATE_TIME,
-     get_image_resize_size,
-     set_seed_everywhere,
- )
-
- # Set up logging
- logging.basicConfig(
-     level=logging.INFO,
-     format="%(asctime)s [%(levelname)s] %(message)s",
-     handlers=[logging.StreamHandler()],
- )
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class GenerateConfig:
-     # fmt: off
-
-     #################################################################################################################
-     # Model-specific parameters
-     #################################################################################################################
-     model_family: str = "openvla"                     # Model family
-
-     center_crop: bool = True                          # Center crop? (if trained w/ random crop image aug)
-     num_open_loop_steps: int = 25                     # Number of actions to execute open-loop before requerying policy
-
-     use_vla_server: bool = True                       # Whether to query remote VLA server for actions
-     vla_server_url: Union[str, Path] = ""             # Remote VLA server URL (set to 127.0.0.1 if on same machine)
-
-     #################################################################################################################
-     # ALOHA environment-specific parameters
-     #################################################################################################################
-     num_rollouts_planned: int = 50                    # Number of test rollouts
-     max_steps: int = 1500                             # Max number of steps per rollout
-     use_relative_actions: bool = False                # Whether to use relative actions (delta joint angles)
-
-     #################################################################################################################
-     # Utils
-     #################################################################################################################
-     run_id_note: Optional[str] = None                 # Extra note to add to end of run ID for logging
-     local_log_dir: str = "./experiments/logs"         # Local directory for eval logs
-
-     seed: int = 7                                     # Random Seed (for reproducibility)
-
-     # fmt: on
-
-
- def validate_config(cfg: GenerateConfig) -> None:
-     """Validate configuration parameters."""
-     assert cfg.use_vla_server, (
-         "Must use VLA server (server-client interface) to query model and get actions! Please set --use_vla_server=True"
-     )
-
-
- def setup_logging(cfg: GenerateConfig):
-     """Set up logging to file."""
-     # Create run ID
-     run_id = f"EVAL-{cfg.model_family}-{DATE_TIME}"
-     if cfg.run_id_note is not None:
-         run_id += f"--{cfg.run_id_note}"
-
-     # Set up local logging
-     os.makedirs(cfg.local_log_dir, exist_ok=True)
-     local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
-     log_file = open(local_log_filepath, "w")
-     logger.info(f"Logging to local log file: {local_log_filepath}")
-
-     return log_file, local_log_filepath, run_id
-
-
- def log_message(message: str, log_file=None):
-     """Log a message to console and optionally to a log file."""
-     print(message)
-     logger.info(message)
-     if log_file:
-         log_file.write(message + "\n")
-         log_file.flush()
-
-
- def get_server_endpoint(cfg: GenerateConfig):
-     """Get the server endpoint for remote inference."""
-     ip_address = socket.gethostbyname(cfg.vla_server_url)
-     return f"http://{ip_address}:8777/act"
-
-
- def prepare_observation(obs, resize_size):
-     """Prepare observation for policy input."""
-     # Get preprocessed images
-     img = get_aloha_image(obs)
-     left_wrist_img, right_wrist_img = get_aloha_wrist_images(obs)
-
-     # Resize images to size expected by model
-     img_resized = resize_image_for_policy(img, resize_size)
-     left_wrist_img_resized = resize_image_for_policy(left_wrist_img, resize_size)
-     right_wrist_img_resized = resize_image_for_policy(right_wrist_img, resize_size)
-
-     # Prepare observations dict
-     observation = {
-         "full_image": img_resized,
-         "left_wrist_image": left_wrist_img_resized,
-         "right_wrist_image": right_wrist_img_resized,
-         "state": obs.observation["qpos"],
-     }
-
-     return observation, img_resized, left_wrist_img_resized, right_wrist_img_resized
-
-
- def run_episode(
-     cfg: GenerateConfig,
-     env,
-     task_description: str,
-     server_endpoint: str,
-     resize_size,
-     log_file=None,
- ):
-     """Run a single episode in the ALOHA environment."""
-     # Define control frequency
-     STEP_DURATION_IN_SEC = 1.0 / 25.0
-
-     # Reset environment
-     obs = env.reset()
-
-     # Initialize action queue
-     action_queue = deque(maxlen=cfg.num_open_loop_steps)
-
-     # Setup
-     t = 0
-     curr_state = None
-     replay_images = []
-     replay_images_resized = []
-     replay_images_left_wrist_resized = []
-     replay_images_right_wrist_resized = []
-
-     log_message("Prepare the scene, and then press Enter to begin...", log_file)
-     input()
-
-     # Reset environment again to fetch first timestep observation
-     obs = env.reset()
-
-     # Fetch initial robot state (but sleep first so that the robot stops moving)
-     time.sleep(2)
-     curr_state = env.get_qpos()
-
-     episode_start_time = time.time()
-     total_model_query_time = 0.0
-
-     try:
-         while t < cfg.max_steps:
-             # Get step start time (used to compute how much to sleep between steps)
-             step_start_time = time.time()
-
-             # Get observation
-             obs = env.get_observation(t=t)
-
-             # Save raw high camera image for replay video
-             replay_images.append(obs.observation["images"]["cam_high"])
-
-             # If action queue is empty, requery model
-             if len(action_queue) == 0:
-                 # Prepare observation
-                 observation, img_resized, left_wrist_resized, right_wrist_resized = prepare_observation(obs, resize_size)
-                 observation["instruction"] = task_description
-
-                 # Save processed images for replay
-                 replay_images_resized.append(img_resized)
-                 replay_images_left_wrist_resized.append(left_wrist_resized)
-                 replay_images_right_wrist_resized.append(right_wrist_resized)
-
-                 # Query model to get action
-                 log_message("Requerying model...", log_file)
-                 model_query_start_time = time.time()
-                 actions = get_action_from_server(observation, server_endpoint)
-                 actions = actions[: cfg.num_open_loop_steps]
-                 total_model_query_time += time.time() - model_query_start_time
-                 action_queue.extend(actions)
-
-             # Get action from queue
-             action = action_queue.popleft()
-             log_message("-----------------------------------------------------", log_file)
-             log_message(f"t: {t}", log_file)
-             log_message(f"action: {action}", log_file)
-
-             # Execute action in environment
-             if cfg.use_relative_actions:
-                 # Get absolute joint angles from relative action
-                 rel_action = action
-                 target_state = curr_state + rel_action
-                 obs = env.step(target_state.tolist())
-                 # Update current state (assume it is the commanded target state)
-                 curr_state = target_state
-             else:
-                 obs = env.step(action.tolist())
-             t += 1
-
-             # Sleep until next timestep
-             step_elapsed_time = time.time() - step_start_time
-             if step_elapsed_time < STEP_DURATION_IN_SEC:
-                 time_to_sleep = STEP_DURATION_IN_SEC - step_elapsed_time
-                 log_message(f"Sleeping {time_to_sleep:.3f} sec...", log_file)
-                 time.sleep(time_to_sleep)
-
-     except (KeyboardInterrupt, Exception) as e:
-         if isinstance(e, KeyboardInterrupt):
-             log_message("\nCaught KeyboardInterrupt: Terminating episode early.", log_file)
-         else:
-             log_message(f"\nCaught exception: {e}", log_file)
-
-     episode_end_time = time.time()
242
-
243
- # Get success feedback from user
244
- user_input = input("Success? Enter 'y' or 'n': ")
245
- success = True if user_input.lower() == "y" else False
246
-
247
- # Calculate episode statistics
248
- episode_stats = {
249
- "success": success,
250
- "total_steps": t,
251
- "model_query_time": total_model_query_time,
252
- "episode_duration": episode_end_time - episode_start_time,
253
- }
254
-
255
- return (
256
- episode_stats,
257
- replay_images,
258
- replay_images_resized,
259
- replay_images_left_wrist_resized,
260
- replay_images_right_wrist_resized,
261
- )
262
-
263
-
264
- def save_episode_videos(
265
- replay_images,
266
- replay_images_resized,
267
- replay_images_left_wrist,
268
- replay_images_right_wrist,
269
- episode_idx,
270
- success,
271
- task_description,
272
- log_file=None,
273
- ):
274
- """Save videos of the episode from different camera angles."""
275
- # Save main replay video
276
- save_rollout_video(replay_images, episode_idx, success=success, task_description=task_description, log_file=log_file)
277
-
278
- # Save processed view videos
279
- save_rollout_video(
280
- replay_images_resized,
281
- episode_idx,
282
- success=success,
283
- task_description=task_description,
284
- log_file=log_file,
285
- notes="resized",
286
- )
287
- save_rollout_video(
288
- replay_images_left_wrist,
289
- episode_idx,
290
- success=success,
291
- task_description=task_description,
292
- log_file=log_file,
293
- notes="left_wrist_resized",
294
- )
295
- save_rollout_video(
296
- replay_images_right_wrist,
297
- episode_idx,
298
- success=success,
299
- task_description=task_description,
300
- log_file=log_file,
301
- notes="right_wrist_resized",
302
- )
303
-
304
-
305
- @draccus.wrap()
306
- def eval_aloha(cfg: GenerateConfig) -> None:
307
- """Main function to evaluate a trained policy in a real-world ALOHA environment."""
308
- # Validate configuration
309
- validate_config(cfg)
310
-
311
- # Set random seed
312
- set_seed_everywhere(cfg.seed)
313
-
314
- # Setup logging
315
- log_file, local_log_filepath, run_id = setup_logging(cfg)
316
-
317
- # Get expected image dimensions
318
- resize_size = get_image_resize_size(cfg)
319
-
320
- # Get ALOHA environment
321
- env = get_aloha_env()
322
-
323
- # Get server endpoint for remote inference
324
- server_endpoint = get_server_endpoint(cfg)
325
-
326
- # Initialize task description
327
- task_description = ""
328
-
329
- # Start evaluation
330
- num_rollouts_completed, total_successes = 0, 0
331
-
332
- for episode_idx in tqdm.tqdm(range(cfg.num_rollouts_planned)):
333
- # Get task description from user
334
- task_description = get_next_task_label(task_description)
335
- log_message(f"\nTask: {task_description}", log_file)
336
-
337
- log_message(f"Starting episode {num_rollouts_completed + 1}...", log_file)
338
-
339
- # Run episode
340
- episode_stats, replay_images, replay_images_resized, replay_images_left_wrist, replay_images_right_wrist = (
341
- run_episode(cfg, env, task_description, server_endpoint, resize_size, log_file)
342
- )
343
-
344
- # Update counters
345
- num_rollouts_completed += 1
346
- if episode_stats["success"]:
347
- total_successes += 1
348
-
349
- # Save videos
350
- save_episode_videos(
351
- replay_images,
352
- replay_images_resized,
353
- replay_images_left_wrist,
354
- replay_images_right_wrist,
355
- num_rollouts_completed,
356
- episode_stats["success"],
357
- task_description,
358
- log_file,
359
- )
360
-
361
- # Log results
362
- log_message(f"Success: {episode_stats['success']}", log_file)
363
- log_message(f"# episodes completed so far: {num_rollouts_completed}", log_file)
364
- log_message(f"# successes: {total_successes} ({total_successes / num_rollouts_completed * 100:.1f}%)", log_file)
365
- log_message(f"Total model query time: {episode_stats['model_query_time']:.2f} sec", log_file)
366
- log_message(f"Total episode elapsed time: {episode_stats['episode_duration']:.2f} sec", log_file)
367
-
368
- # Calculate final success rate
369
- final_success_rate = float(total_successes) / float(num_rollouts_completed) if num_rollouts_completed > 0 else 0
370
-
371
- # Log final results
372
- log_message("\nFinal results:", log_file)
373
- log_message(f"Total episodes: {num_rollouts_completed}", log_file)
374
- log_message(f"Total successes: {total_successes}", log_file)
375
- log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file)
376
-
377
- # Close log file
378
- if log_file:
379
- log_file.close()
380
-
381
- return final_success_rate
382
-
383
-
384
- if __name__ == "__main__":
385
- eval_aloha()
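For reference, the episode loop above follows a standard open-loop action-chunking pattern: query the policy once, execute a chunk of actions, and requery only when the queue is empty. A minimal self-contained sketch of that pattern follows; `query_policy` is a hypothetical stand-in for the remote VLA server call (`get_action_from_server`), and the 14-dimensional action (bimanual ALOHA joint targets) is an assumption for illustration.

from collections import deque

import numpy as np

NUM_OPEN_LOOP_STEPS = 25  # mirrors cfg.num_open_loop_steps above

def query_policy(observation):
    # Hypothetical stand-in: the real script POSTs the observation to the
    # server's /act endpoint and receives a chunk of future actions.
    return np.zeros((NUM_OPEN_LOOP_STEPS, 14))  # assumed 14-DoF bimanual actions

action_queue = deque(maxlen=NUM_OPEN_LOOP_STEPS)
for t in range(100):
    if len(action_queue) == 0:  # chunk exhausted -> requery the policy
        action_queue.extend(query_policy({"t": t}))
    action = action_queue.popleft()
    # env.step(action.tolist()) would execute the action here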
capvector-oft/experiments/robot/libero/libero_requirements.txt DELETED
@@ -1,6 +0,0 @@
- imageio[ffmpeg]
- robosuite==1.4.1
- bddl
- easydict
- cloudpickle
- gym
capvector-oft/experiments/robot/libero/libero_utils.py DELETED
@@ -1,87 +0,0 @@
- """Utils for evaluating policies in LIBERO simulation environments."""
-
- import math
- import os
-
- import imageio
- import numpy as np
- import tensorflow as tf
- from libero.libero import get_libero_path
- from libero.libero.envs import OffScreenRenderEnv
-
- from experiments.robot.robot_utils import (
-     DATE,
-     DATE_TIME,
- )
-
-
- def get_libero_env(task, model_family, resolution=256):
-     """Initializes and returns the LIBERO environment, along with the task description."""
-     task_description = task.language
-     task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
-     env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
-     env = OffScreenRenderEnv(**env_args)
-     env.seed(0)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
-     return env, task_description
-
-
- def get_libero_dummy_action(model_family: str):
-     """Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
-     return [0, 0, 0, 0, 0, 0, -1]
-
-
- def get_libero_image(obs):
-     """Extracts third-person image from observations and preprocesses it."""
-     img = obs["agentview_image"]
-     img = img[::-1, ::-1]  # IMPORTANT: rotate 180 degrees to match train preprocessing
-     return img
-
-
- def get_libero_wrist_image(obs):
-     """Extracts wrist camera image from observations and preprocesses it."""
-     img = obs["robot0_eye_in_hand_image"]
-     img = img[::-1, ::-1]  # IMPORTANT: rotate 180 degrees to match train preprocessing
-     return img
-
-
- def save_rollout_video(rollout_images, idx, success, task_description, log_file=None):
-     """Saves an MP4 replay of an episode."""
-     rollout_dir = f"./rollouts/{DATE}"
-     os.makedirs(rollout_dir, exist_ok=True)
-     processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
-     mp4_path = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}.mp4"
-     video_writer = imageio.get_writer(mp4_path, fps=30)
-     for img in rollout_images:
-         video_writer.append_data(img)
-     video_writer.close()
-     print(f"Saved rollout MP4 at path {mp4_path}")
-     if log_file is not None:
-         log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
-     return mp4_path
-
-
- def quat2axisangle(quat):
-     """
-     Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
-
-     Converts quaternion to axis-angle format.
-     Returns a unit vector direction scaled by its angle in radians.
-
-     Args:
-         quat (np.array): (x,y,z,w) vec4 float angles
-
-     Returns:
-         np.array: (ax,ay,az) axis-angle exponential coordinates
-     """
-     # clip quaternion
-     if quat[3] > 1.0:
-         quat[3] = 1.0
-     elif quat[3] < -1.0:
-         quat[3] = -1.0
-
-     den = np.sqrt(1.0 - quat[3] * quat[3])
-     if math.isclose(den, 0.0):
-         # This is (close to) a zero degree rotation, immediately return
-         return np.zeros(3)
-
-     return (quat[:3] * 2.0 * math.acos(quat[3])) / den
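As a quick sanity check of quat2axisangle above (toy values, assuming the function defined in this file is in scope): a 90-degree rotation about the z-axis is the quaternion (x, y, z, w) = (0, 0, sin 45°, cos 45°) and should map to the axis-angle vector (0, 0, pi/2).

import math

import numpy as np

quat = np.array([0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)])
axis_angle = quat2axisangle(quat)  # the function defined above
assert np.allclose(axis_angle, [0.0, 0.0, math.pi / 2])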
capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py DELETED
@@ -1,249 +0,0 @@
- """
- Regenerates a LIBERO dataset (HDF5 files) by replaying demonstrations in the environments.
-
- Notes:
-     - We save image observations at 256x256px resolution (instead of 128x128).
-     - We filter out transitions with "no-op" (zero) actions that do not change the robot's state.
-     - We filter out unsuccessful demonstrations.
-     - In the LIBERO HDF5 data -> RLDS data conversion (not shown here), we rotate the images by
-       180 degrees because we observe that the environments return images that are upside down
-       on our platform.
-
- Usage:
-     python experiments/robot/libero/regenerate_libero_dataset.py \
-         --libero_task_suite [ libero_spatial | libero_object | libero_goal | libero_10 ] \
-         --libero_raw_data_dir <PATH TO RAW HDF5 DATASET DIR> \
-         --libero_target_dir <PATH TO TARGET DIR>
-
- Example (LIBERO-Spatial):
-     python experiments/robot/libero/regenerate_libero_dataset.py \
-         --libero_task_suite libero_spatial \
-         --libero_raw_data_dir ./LIBERO/libero/datasets/libero_spatial \
-         --libero_target_dir ./LIBERO/libero/datasets/libero_spatial_no_noops
-
- """
-
- import argparse
- import json
- import os
- import time
-
- import h5py
- import numpy as np
- import robosuite.utils.transform_utils as T
- import tqdm
- from libero.libero import benchmark
-
- from experiments.robot.libero.libero_utils import (
-     get_libero_dummy_action,
-     get_libero_env,
- )
-
-
- IMAGE_RESOLUTION = 256
-
-
- def is_noop(action, prev_action=None, threshold=1e-4):
-     """
-     Returns whether an action is a no-op action.
-
-     A no-op action satisfies two criteria:
-         (1) All action dimensions, except for the last one (gripper action), are near zero.
-         (2) The gripper action is equal to the previous timestep's gripper action.
-
-     Explanation of (2):
-         Naively filtering out actions with just criterion (1) is not good because you will
-         remove actions where the robot is staying still but opening/closing its gripper.
-         So you also need to consider the current state (by checking the previous timestep's
-         gripper action as a proxy) to determine whether the action really is a no-op.
-     """
-     # Special case: Previous action is None if this is the first action in the episode
-     # Then we only care about criterion (1)
-     if prev_action is None:
-         return np.linalg.norm(action[:-1]) < threshold
-
-     # Normal case: Check both criteria (1) and (2)
-     gripper_action = action[-1]
-     prev_gripper_action = prev_action[-1]
-     return np.linalg.norm(action[:-1]) < threshold and gripper_action == prev_gripper_action
-
-
- def main(args):
-     print(f"Regenerating {args.libero_task_suite} dataset!")
-
-     # Create target directory
-     if os.path.isdir(args.libero_target_dir):
-         user_input = input(f"Target directory already exists at path: {args.libero_target_dir}\nEnter 'y' to overwrite the directory, or anything else to exit: ")
-         if user_input != 'y':
-             exit()
-     os.makedirs(args.libero_target_dir, exist_ok=True)
-
-     # Prepare JSON file to record success/failure and initial states per episode
-     metainfo_json_dict = {}
-     metainfo_json_out_path = f"./experiments/robot/libero/{args.libero_task_suite}_metainfo.json"
-     with open(metainfo_json_out_path, "w") as f:
-         # Just test that we can write to this file (we overwrite it later)
-         json.dump(metainfo_json_dict, f)
-
-     # Get task suite
-     benchmark_dict = benchmark.get_benchmark_dict()
-     task_suite = benchmark_dict[args.libero_task_suite]()
-     num_tasks_in_suite = task_suite.n_tasks
-
-     # Setup
-     num_replays = 0
-     num_success = 0
-     num_noops = 0
-
-     for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
-         # Get task in suite
-         task = task_suite.get_task(task_id)
-         env, task_description = get_libero_env(task, "llava", resolution=IMAGE_RESOLUTION)
-
-         # Get dataset for task
-         orig_data_path = os.path.join(args.libero_raw_data_dir, f"{task.name}_demo.hdf5")
-         assert os.path.exists(orig_data_path), f"Cannot find raw data file {orig_data_path}."
-         orig_data_file = h5py.File(orig_data_path, "r")
-         orig_data = orig_data_file["data"]
-
-         # Create new HDF5 file for regenerated demos
-         new_data_path = os.path.join(args.libero_target_dir, f"{task.name}_demo.hdf5")
-         new_data_file = h5py.File(new_data_path, "w")
-         grp = new_data_file.create_group("data")
-
-         for i in range(len(orig_data.keys())):
-             # Get demo data
-             demo_data = orig_data[f"demo_{i}"]
-             orig_actions = demo_data["actions"][()]
-             orig_states = demo_data["states"][()]
-
-             # Reset environment, set initial state, and wait a few steps for environment to settle
-             env.reset()
-             env.set_init_state(orig_states[0])
-             for _ in range(10):
-                 obs, reward, done, info = env.step(get_libero_dummy_action("llava"))
-
-             # Set up new data lists
-             states = []
-             actions = []
-             ee_states = []
-             gripper_states = []
-             joint_states = []
-             robot_states = []
-             agentview_images = []
-             eye_in_hand_images = []
-
-             # Replay original demo actions in environment and record observations
-             for _, action in enumerate(orig_actions):
-                 # Skip transitions with no-op actions
-                 prev_action = actions[-1] if len(actions) > 0 else None
-                 if is_noop(action, prev_action):
-                     print(f"\tSkipping no-op action: {action}")
-                     num_noops += 1
-                     continue
-
-                 if states == []:
-                     # In the first timestep, since we're using the original initial state to initialize the environment,
-                     # copy the initial state (first state in episode) over from the original HDF5 to the new one
-                     states.append(orig_states[0])
-                     robot_states.append(demo_data["robot_states"][0])
-                 else:
-                     # For all other timesteps, get state from environment and record it
-                     states.append(env.sim.get_state().flatten())
-                     robot_states.append(
-                         np.concatenate([obs["robot0_gripper_qpos"], obs["robot0_eef_pos"], obs["robot0_eef_quat"]])
-                     )
-
-                 # Record original action (from demo)
-                 actions.append(action)
-
-                 # Record data returned by environment
-                 if "robot0_gripper_qpos" in obs:
-                     gripper_states.append(obs["robot0_gripper_qpos"])
-                 joint_states.append(obs["robot0_joint_pos"])
-                 ee_states.append(
-                     np.hstack(
-                         (
-                             obs["robot0_eef_pos"],
-                             T.quat2axisangle(obs["robot0_eef_quat"]),
-                         )
-                     )
-                 )
-                 agentview_images.append(obs["agentview_image"])
-                 eye_in_hand_images.append(obs["robot0_eye_in_hand_image"])
-
-                 # Execute demo action in environment
-                 obs, reward, done, info = env.step(action.tolist())
-
-             # At end of episode, save replayed trajectories to new HDF5 files (only keep successes)
-             if done:
-                 dones = np.zeros(len(actions)).astype(np.uint8)
-                 dones[-1] = 1
-                 rewards = np.zeros(len(actions)).astype(np.uint8)
-                 rewards[-1] = 1
-                 assert len(actions) == len(agentview_images)
-
-                 ep_data_grp = grp.create_group(f"demo_{i}")
-                 obs_grp = ep_data_grp.create_group("obs")
-                 obs_grp.create_dataset("gripper_states", data=np.stack(gripper_states, axis=0))
-                 obs_grp.create_dataset("joint_states", data=np.stack(joint_states, axis=0))
-                 obs_grp.create_dataset("ee_states", data=np.stack(ee_states, axis=0))
-                 obs_grp.create_dataset("ee_pos", data=np.stack(ee_states, axis=0)[:, :3])
-                 obs_grp.create_dataset("ee_ori", data=np.stack(ee_states, axis=0)[:, 3:])
-                 obs_grp.create_dataset("agentview_rgb", data=np.stack(agentview_images, axis=0))
-                 obs_grp.create_dataset("eye_in_hand_rgb", data=np.stack(eye_in_hand_images, axis=0))
-                 ep_data_grp.create_dataset("actions", data=actions)
-                 ep_data_grp.create_dataset("states", data=np.stack(states))
-                 ep_data_grp.create_dataset("robot_states", data=np.stack(robot_states, axis=0))
-                 ep_data_grp.create_dataset("rewards", data=rewards)
-                 ep_data_grp.create_dataset("dones", data=dones)
-
-                 num_success += 1
-
-             num_replays += 1
-
-             # Record success/failure and initial environment state in metainfo dict
-             task_key = task_description.replace(" ", "_")
-             episode_key = f"demo_{i}"
-             if task_key not in metainfo_json_dict:
-                 metainfo_json_dict[task_key] = {}
-             if episode_key not in metainfo_json_dict[task_key]:
-                 metainfo_json_dict[task_key][episode_key] = {}
-             metainfo_json_dict[task_key][episode_key]["success"] = bool(done)
-             metainfo_json_dict[task_key][episode_key]["initial_state"] = orig_states[0].tolist()
-
-             # Write metainfo dict to JSON file
-             # (We repeatedly overwrite, rather than doing this once at the end, just in case the script crashes midway)
-             with open(metainfo_json_out_path, "w") as f:
-                 json.dump(metainfo_json_dict, f, indent=2)
-
-             # Count total number of successful replays so far
-             print(
-                 f"Total # episodes replayed: {num_replays}, Total # successes: {num_success} ({num_success / num_replays * 100:.1f} %)"
-             )
-
-             # Report total number of no-op actions filtered out so far
-             print(f" Total # no-op actions filtered out: {num_noops}")
-
-         # Close HDF5 files
-         orig_data_file.close()
-         new_data_file.close()
-         print(f"Saved regenerated demos for task '{task_description}' at: {new_data_path}")
-
-     print(f"Dataset regeneration complete! Saved new dataset at: {args.libero_target_dir}")
-     print(f"Saved metainfo JSON at: {metainfo_json_out_path}")
-
-
- if __name__ == "__main__":
-     # Parse command-line arguments
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--libero_task_suite", type=str, choices=["libero_spatial", "libero_object", "libero_goal", "libero_10", "libero_90"],
-                         help="LIBERO task suite. Example: libero_spatial", required=True)
-     parser.add_argument("--libero_raw_data_dir", type=str,
-                         help="Path to directory containing raw HDF5 dataset. Example: ./LIBERO/libero/datasets/libero_spatial", required=True)
-     parser.add_argument("--libero_target_dir", type=str,
-                         help="Path to regenerated dataset directory. Example: ./LIBERO/libero/datasets/libero_spatial_no_noops", required=True)
-     args = parser.parse_args()
-
-     # Start data regeneration
-     main(args)
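The two filtering criteria in is_noop above can be seen on toy 7-DoF actions (6 pose deltas plus gripper), assuming the function defined in this file is in scope: a zero-motion action with an unchanged gripper is filtered out, while the same zero motion with a toggled gripper is kept.

import numpy as np

prev = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0])
still_same_grip = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0])
still_new_grip = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])

assert is_noop(still_same_grip, prev)      # no motion, gripper unchanged -> filtered out
assert not is_noop(still_new_grip, prev)   # no motion, but gripper toggles -> kept
assert is_noop(still_same_grip, None)      # first action: only criterion (1) applies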
capvector-oft/experiments/robot/libero/run_libero_eval.py DELETED
@@ -1,540 +0,0 @@
- """
- run_libero_eval.py
-
- Evaluates a trained policy in a LIBERO simulation benchmark task suite.
- """
-
- import json
- import logging
- import os
- import sys
- from collections import deque
- from dataclasses import dataclass
- from enum import Enum
- from pathlib import Path
- from typing import Optional, Union
-
- import draccus
- import numpy as np
- import tqdm
- from libero.libero import benchmark
-
- import wandb
-
- # Append current directory so that interpreter can find experiments.robot
- sys.path.append("../..")
- from experiments.robot.libero.libero_utils import (
-     get_libero_dummy_action,
-     get_libero_env,
-     get_libero_image,
-     get_libero_wrist_image,
-     quat2axisangle,
-     save_rollout_video,
- )
- from experiments.robot.openvla_utils import (
-     get_action_head,
-     get_noisy_action_projector,
-     get_processor,
-     get_proprio_projector,
-     resize_image_for_policy,
- )
- from experiments.robot.robot_utils import (
-     DATE_TIME,
-     get_action,
-     get_image_resize_size,
-     get_model,
-     invert_gripper_action,
-     normalize_gripper_action,
-     set_seed_everywhere,
- )
- from prismatic.vla.constants import NUM_ACTIONS_CHUNK
-
-
- # import debugpy
- # try:
- #     debugpy.listen(("localhost", 9501))
- #     print("Waiting for debugger attach")
- #     debugpy.wait_for_client()
- # except Exception as e:
- #     pass
-
-
- # Define task suite constants
- class TaskSuite(str, Enum):
-     LIBERO_SPATIAL = "libero_spatial"
-     LIBERO_OBJECT = "libero_object"
-     LIBERO_GOAL = "libero_goal"
-     LIBERO_10 = "libero_10"
-     LIBERO_90 = "libero_90"
-
-
- # Define max steps for each task suite
- TASK_MAX_STEPS = {
-     TaskSuite.LIBERO_SPATIAL: 220,  # longest training demo has 193 steps
-     TaskSuite.LIBERO_OBJECT: 280,   # longest training demo has 254 steps
-     TaskSuite.LIBERO_GOAL: 300,     # longest training demo has 270 steps
-     TaskSuite.LIBERO_10: 520,       # longest training demo has 505 steps
-     TaskSuite.LIBERO_90: 400,       # longest training demo has 373 steps
- }
-
-
- # Set up logging
- logging.basicConfig(
-     level=logging.INFO,
-     format="%(asctime)s [%(levelname)s] %(message)s",
-     handlers=[logging.StreamHandler()],
- )
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class GenerateConfig:
-     # fmt: off
-
-     #################################################################################################################
-     # Model-specific parameters
-     #################################################################################################################
-     model_family: str = "openvla"                    # Model family
-     pretrained_checkpoint: Union[str, Path] = ""     # Pretrained checkpoint path
-
-     use_l1_regression: bool = True                   # If True, uses continuous action head with L1 regression objective
-     use_diffusion: bool = False                      # If True, uses continuous action head with diffusion modeling objective (DDIM)
-     num_diffusion_steps_train: int = 50              # (When `use_diffusion==True`) Number of diffusion steps used for training
-     num_diffusion_steps_inference: int = 50          # (When `use_diffusion==True`) Number of diffusion steps used for inference
-     use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
-     num_images_in_input: int = 2                     # Number of images in the VLA input
-     use_proprio: bool = True                         # Whether to include proprio state in input
-
-     center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
-     num_open_loop_steps: int = 8                     # Number of actions to execute open-loop before requerying policy
-
-     lora_rank: int = 32                              # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
-
-     unnorm_key: Union[str, Path] = ""                # Action un-normalization key
-
-     load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
-     load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization
-
-     #################################################################################################################
-     # LIBERO environment-specific parameters
-     #################################################################################################################
-     task_suite_name: str = TaskSuite.LIBERO_SPATIAL  # Task suite
-     num_steps_wait: int = 10                         # Number of steps to wait for objects to stabilize in sim
-     num_trials_per_task: int = 50                    # Number of rollouts per task
-     initial_states_path: str = "DEFAULT"             # "DEFAULT", or path to initial states JSON file
-     env_img_res: int = 256                           # Resolution for environment images (not policy input resolution)
-
-     #################################################################################################################
-     # Utils
-     #################################################################################################################
-     run_id_note: Optional[str] = None                # Extra note to add to end of run ID for logging
-     local_log_dir: str = "./experiments/logs"        # Local directory for eval logs
-
-     use_wandb: bool = False                          # Whether to also log results in Weights & Biases
-     wandb_entity: str = "your-wandb-entity"          # Name of WandB entity
-     wandb_project: str = "your-wandb-project"        # Name of WandB project
-
-     seed: int = 7                                    # Random Seed (for reproducibility)
-
-     # fmt: on
-
-
- def validate_config(cfg: GenerateConfig) -> None:
-     """Validate configuration parameters."""
-     assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"
-
-     if "image_aug" in str(cfg.pretrained_checkpoint):
-         assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
-
-     assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"
-
-     # Validate task suite
-     assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
-
-
- def initialize_model(cfg: GenerateConfig):
-     """Initialize model and associated components."""
-     # Load model
-     model = get_model(cfg)
-
-     # Load proprio projector if needed
-     proprio_projector = None
-     if cfg.use_proprio:
-         proprio_projector = get_proprio_projector(
-             cfg,
-             model.llm_dim,
-             proprio_dim=8,  # 8-dimensional proprio for LIBERO
-         )
-
-     # Load action head if needed
-     action_head = None
-     if cfg.use_l1_regression or cfg.use_diffusion:
-         action_head = get_action_head(cfg, model.llm_dim)
-
-     # Load noisy action projector if using diffusion
-     noisy_action_projector = None
-     if cfg.use_diffusion:
-         noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)
-
-     # Get OpenVLA processor if needed
-     processor = None
-     if cfg.model_family == "openvla":
-         processor = get_processor(cfg)
-         check_unnorm_key(cfg, model)
-
-     return model, action_head, proprio_projector, noisy_action_projector, processor
-
-
- def check_unnorm_key(cfg: GenerateConfig, model) -> None:
-     """Check that the model contains the action un-normalization key."""
-     # Initialize unnorm_key
-     unnorm_key = cfg.task_suite_name
-
-     # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
-     # with the suffix "_no_noops" in the dataset name)
-     if unnorm_key not in model.norm_stats and f"{unnorm_key}_no_noops" in model.norm_stats:
-         unnorm_key = f"{unnorm_key}_no_noops"
-
-     assert unnorm_key in model.norm_stats, f"Action un-norm key {unnorm_key} not found in VLA `norm_stats`!"
-
-     # Set the unnorm_key in cfg
-     cfg.unnorm_key = unnorm_key
-
-
- def setup_logging(cfg: GenerateConfig):
-     """Set up logging to file and optionally to wandb."""
-     # Create run ID
-     run_id = f"EVAL-{cfg.task_suite_name}-{cfg.model_family}-{DATE_TIME}"
-     if cfg.run_id_note is not None:
-         run_id += f"--{cfg.run_id_note}"
-
-     # Set up local logging
-     os.makedirs(cfg.local_log_dir, exist_ok=True)
-     local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
-     log_file = open(local_log_filepath, "w")
-     logger.info(f"Logging to local log file: {local_log_filepath}")
-
-     # Initialize Weights & Biases logging if enabled
-     if cfg.use_wandb:
-         wandb.init(
-             entity=cfg.wandb_entity,
-             project=cfg.wandb_project,
-             name=run_id,
-         )
-
-     return log_file, local_log_filepath, run_id
-
-
- def log_message(message: str, log_file=None):
-     """Log a message to console and optionally to a log file."""
-     logger.info(message)
-     if log_file:
-         log_file.write(message + "\n")
-         log_file.flush()
-
-
- def load_initial_states(cfg: GenerateConfig, task_suite, task_id: int, log_file=None):
-     """Load initial states for the given task."""
-     # Get default initial states
-     initial_states = task_suite.get_task_init_states(task_id)
-
-     # If using custom initial states, load them from file
-     if cfg.initial_states_path != "DEFAULT":
-         with open(cfg.initial_states_path, "r") as f:
-             all_initial_states = json.load(f)
-         log_message(f"Using initial states from {cfg.initial_states_path}", log_file)
-         return initial_states, all_initial_states
-     else:
-         log_message("Using default initial states", log_file)
-         return initial_states, None
-
-
- def prepare_observation(obs, resize_size):
-     """Prepare observation for policy input."""
-     # Get preprocessed images
-     img = get_libero_image(obs)
-     wrist_img = get_libero_wrist_image(obs)
-
-     # Resize images to size expected by model
-     img_resized = resize_image_for_policy(img, resize_size)
-     wrist_img_resized = resize_image_for_policy(wrist_img, resize_size)
-
-     # Prepare observations dict
-     observation = {
-         "full_image": img_resized,
-         "wrist_image": wrist_img_resized,
-         "state": np.concatenate(
-             (obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"])
-         ),
-     }
-
-     return observation, img  # Return both processed observation and original image for replay
-
-
- def process_action(action, model_family):
-     """Process action before sending to environment."""
-     # Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter
-     action = normalize_gripper_action(action, binarize=True)
-
-     # [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets
-     # (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action
-     if model_family == "openvla":
-         action = invert_gripper_action(action)
-
-     return action
-
-
- def run_episode(
-     cfg: GenerateConfig,
-     env,
-     task_description: str,
-     model,
-     resize_size,
-     processor=None,
-     action_head=None,
-     proprio_projector=None,
-     noisy_action_projector=None,
-     initial_state=None,
-     log_file=None,
- ):
-     """Run a single episode in the environment."""
-     # Reset environment
-     env.reset()
-
-     # Set initial state if provided
-     if initial_state is not None:
-         obs = env.set_init_state(initial_state)
-     else:
-         obs = env.get_observation()
-
-     # Initialize action queue
-     if cfg.num_open_loop_steps != NUM_ACTIONS_CHUNK:
-         print(f"WARNING: cfg.num_open_loop_steps ({cfg.num_open_loop_steps}) does not match the NUM_ACTIONS_CHUNK "
-               f"({NUM_ACTIONS_CHUNK}) constant defined in prismatic.vla.constants! For best performance (in terms of "
-               "both speed and success rate), we recommend executing the full action chunk.")
-     action_queue = deque(maxlen=cfg.num_open_loop_steps)
-
-     # Setup
-     t = 0
-     replay_images = []
-     max_steps = TASK_MAX_STEPS[cfg.task_suite_name]
-
-     # Run episode
-     success = False
-     try:
-         while t < max_steps + cfg.num_steps_wait:
-             # Do nothing for the first few timesteps to let objects stabilize
-             if t < cfg.num_steps_wait:
-                 obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family))
-                 t += 1
-                 continue
-
-             # Prepare observation
-             observation, img = prepare_observation(obs, resize_size)
-             replay_images.append(img)
-
-             # If action queue is empty, requery model
-             if len(action_queue) == 0:
-                 # Query model to get action
-                 actions = get_action(
-                     cfg,
-                     model,
-                     observation,
-                     task_description,
-                     processor=processor,
-                     action_head=action_head,
-                     proprio_projector=proprio_projector,
-                     noisy_action_projector=noisy_action_projector,
-                     use_film=cfg.use_film,
-                 )
-                 action_queue.extend(actions)
-
-             # Get action from queue
-             action = action_queue.popleft()
-
-             # Process action
-             action = process_action(action, cfg.model_family)
-
-             # Execute action in environment
-             obs, reward, done, info = env.step(action.tolist())
-             if done:
-                 success = True
-                 break
-             t += 1
-
-     except Exception as e:
-         log_message(f"Episode error: {e}", log_file)
-
-     return success, replay_images
-
-
- def run_task(
-     cfg: GenerateConfig,
-     task_suite,
-     task_id: int,
-     model,
-     resize_size,
-     processor=None,
-     action_head=None,
-     proprio_projector=None,
-     noisy_action_projector=None,
-     total_episodes=0,
-     total_successes=0,
-     log_file=None,
- ):
-     """Run evaluation for a single task."""
-     # Get task
-     task = task_suite.get_task(task_id)
-
-     # Get initial states
-     initial_states, all_initial_states = load_initial_states(cfg, task_suite, task_id, log_file)
-
-     # Initialize environment and get task description
-     env, task_description = get_libero_env(task, cfg.model_family, resolution=cfg.env_img_res)
-
-     # Start episodes
-     task_episodes, task_successes = 0, 0
-     for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
-         log_message(f"\nTask: {task_description}", log_file)
-
-         # Handle initial state
-         if cfg.initial_states_path == "DEFAULT":
-             # Use default initial state
-             initial_state = initial_states[episode_idx]
-         else:
-             # Get keys for fetching initial episode state from JSON
-             initial_states_task_key = task_description.replace(" ", "_")
-             episode_key = f"demo_{episode_idx}"
-
-             # Skip episode if expert demonstration failed to complete the task
-             if not all_initial_states[initial_states_task_key][episode_key]["success"]:
-                 log_message(f"Skipping task {task_id} episode {episode_idx} due to failed expert demo!", log_file)
-                 continue
-
-             # Get initial state
-             initial_state = np.array(all_initial_states[initial_states_task_key][episode_key]["initial_state"])
-
-         log_message(f"Starting episode {task_episodes + 1}...", log_file)
-
-         # Run episode
-         success, replay_images = run_episode(
-             cfg,
-             env,
-             task_description,
-             model,
-             resize_size,
-             processor,
-             action_head,
-             proprio_projector,
-             noisy_action_projector,
-             initial_state,
-             log_file,
-         )
-
-         # Update counters
-         task_episodes += 1
-         total_episodes += 1
-         if success:
-             task_successes += 1
-             total_successes += 1
-
-         # Save replay video
-         save_rollout_video(
-             replay_images, total_episodes, success=success, task_description=task_description, log_file=log_file
-         )
-
-         # Log results
-         log_message(f"Success: {success}", log_file)
-         log_message(f"# episodes completed so far: {total_episodes}", log_file)
-         log_message(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)", log_file)
-
-     # Log task results
-     task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0
-     total_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0
-
-     log_message(f"Current task success rate: {task_success_rate}", log_file)
-     log_message(f"Current total success rate: {total_success_rate}", log_file)
-
-     # Log to wandb if enabled
-     if cfg.use_wandb:
-         wandb.log(
-             {
-                 f"success_rate/{task_description}": task_success_rate,
-                 f"num_episodes/{task_description}": task_episodes,
-             }
-         )
-
-     return total_episodes, total_successes
-
-
- @draccus.wrap()
- def eval_libero(cfg: GenerateConfig) -> float:
-     """Main function to evaluate a trained policy on LIBERO benchmark tasks."""
-     # Validate configuration
-     validate_config(cfg)
-
-     # Set random seed
-     set_seed_everywhere(cfg.seed)
-
-     # Initialize model and components
-     model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
-
-     # Get expected image dimensions
-     resize_size = get_image_resize_size(cfg)
-
-     # Setup logging
-     log_file, local_log_filepath, run_id = setup_logging(cfg)
-
-     # Initialize LIBERO task suite
-     benchmark_dict = benchmark.get_benchmark_dict()
-     task_suite = benchmark_dict[cfg.task_suite_name]()
-     num_tasks = task_suite.n_tasks
-
-     log_message(f"Task suite: {cfg.task_suite_name}", log_file)
-
-     # Start evaluation
-     total_episodes, total_successes = 0, 0
-     for task_id in tqdm.tqdm(range(num_tasks)):
-         total_episodes, total_successes = run_task(
-             cfg,
-             task_suite,
-             task_id,
-             model,
-             resize_size,
-             processor,
-             action_head,
-             proprio_projector,
-             noisy_action_projector,
-             total_episodes,
-             total_successes,
-             log_file,
-         )
-
-     # Calculate final success rate
-     final_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0
-
-     # Log final results
-     log_message("Final results:", log_file)
-     log_message(f"Total episodes: {total_episodes}", log_file)
-     log_message(f"Total successes: {total_successes}", log_file)
-     log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file)
-
-     # Log to wandb if enabled
-     if cfg.use_wandb:
-         wandb.log(
-             {
-                 "success_rate/total": final_success_rate,
-                 "num_episodes/total": total_episodes,
-             }
-         )
-         wandb.save(local_log_filepath)
-
-     # Close log file
-     if log_file:
-         log_file.close()
-
-     return final_success_rate
-
-
- if __name__ == "__main__":
-     eval_libero()
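The gripper post-processing in process_action above relies on two helpers from experiments.robot.robot_utils that are not shown in this diff. The sketch below only illustrates their assumed behavior (a [0, 1] -> [-1, +1] remap with optional binarization, and a sign flip); it is not the actual implementation, so the function names carry a `_sketch` suffix.

import numpy as np

def normalize_gripper_action_sketch(action, binarize=True):
    # Assumed behavior: remap the last (gripper) dimension from [0, 1] to [-1, +1]
    action = np.asarray(action, dtype=float).copy()
    action[-1] = action[-1] * 2.0 - 1.0
    if binarize:
        action[-1] = 1.0 if action[-1] >= 0.0 else -1.0  # snap to exactly open/closed
    return action

def invert_gripper_action_sketch(action):
    # Assumed behavior: flip the gripper sign convention (-1 = open, +1 = close)
    action = np.asarray(action, dtype=float).copy()
    action[-1] *= -1.0
    return action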
capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:326db6c78dd0a9d91c11f05af03b93fa3095338ee3cb5a5eb15adf3d87eb0109
- size 301501
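The three lines above are a Git LFS pointer, not the pickle itself: the repository stores only the file's SHA-256 and byte size, and LFS serves the real content. If the actual file is needed, it can be fetched with huggingface_hub; the repo_id below is a placeholder, not the real repository name.

from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id="<user>/<this-repo>",  # placeholder: substitute the actual Hub repo id
    filename="capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl",
)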
capvector-oft/experiments/robot/openvla_utils.py DELETED
@@ -1,818 +0,0 @@
1
- """Utils for evaluating OpenVLA or fine-tuned OpenVLA policies."""
2
-
3
- import filecmp
4
- import json
5
- import os
6
- import shutil
7
- import time
8
- from datetime import datetime
9
- from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Tuple, Union
11
-
12
- import json_numpy
13
- import numpy as np
14
- import requests
15
- import tensorflow as tf
16
- import torch
17
- from huggingface_hub import HfApi, hf_hub_download
18
- from PIL import Image
19
- from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
20
-
21
- # Apply JSON numpy patch for serialization
22
- json_numpy.patch()
23
-
24
- from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
25
- from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
26
- from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
27
- from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
28
- from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
29
- from prismatic.models.projectors import NoisyActionProjector, ProprioProjector
30
- from prismatic.vla.constants import (
31
- ACTION_DIM,
32
- ACTION_PROPRIO_NORMALIZATION_TYPE,
33
- )
34
- from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType
35
-
36
- # Initialize important constants
37
- DATE = time.strftime("%Y_%m_%d")
38
- DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
39
- DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
40
- OPENVLA_IMAGE_SIZE = 224 # Standard image size expected by OpenVLA
41
-
42
- # Configure NumPy print settings
43
- np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
44
-
45
-
46
- def model_is_on_hf_hub(model_path: str) -> bool:
47
- """Checks whether a model path points to a model on Hugging Face Hub."""
48
- # If the API call below runs without error, the model is on the hub
49
- try:
50
- HfApi().model_info(model_path)
51
- return True
52
- except Exception:
53
- return False
54
-
55
-
56
- def update_auto_map(pretrained_checkpoint: str) -> None:
57
- """
58
- Update the AutoMap configuration in the checkpoint config.json file.
59
-
60
- This loads the config.json file inside the checkpoint directory and overwrites
61
- the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes.
62
-
63
- Args:
64
- pretrained_checkpoint: Path to the checkpoint directory
65
- """
66
- if not os.path.isdir(pretrained_checkpoint):
67
- return
68
-
69
- config_path = os.path.join(pretrained_checkpoint, "config.json")
70
- if not os.path.exists(config_path):
71
- print(f"Warning: No config.json found at {config_path}")
72
- return
73
-
74
- # Create timestamped backup
75
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
76
- backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}")
77
- shutil.copy2(config_path, backup_path)
78
- print(f"Created backup of original config at: {os.path.abspath(backup_path)}")
79
-
80
- # Read and update the config
81
- with open(config_path, "r") as f:
82
- config = json.load(f)
83
-
84
- config["auto_map"] = {
85
- "AutoConfig": "configuration_prismatic.OpenVLAConfig",
86
- "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction",
87
- }
88
-
89
- # Write back the updated config
90
- with open(config_path, "w") as f:
91
- json.dump(config, f, indent=2)
92
-
93
- print(f"Updated config.json at: {os.path.abspath(config_path)}")
94
- print("Changes made:")
95
- print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"')
96
- print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"')
97
-
98
-
99
- def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:
100
- """
101
- Check if two files are identical in content.
102
-
103
- Args:
104
- path1: Path to the first file
105
- path2: Path to the second file
106
-
107
- Returns:
108
- bool: True if files are identical, False otherwise
109
- """
110
- path1, path2 = Path(path1), Path(path2)
111
-
112
- # First check if file sizes match
113
- if path1.stat().st_size != path2.stat().st_size:
114
- return False
115
-
116
- # Check if contents match
117
- return filecmp.cmp(path1, path2, shallow=False)
118
-
119
-
120
- def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None:
121
- """
122
- Handle syncing of files between current directory and checkpoint.
123
-
124
- Creates backups if files exist but differ, and copies current versions to checkpoint.
125
-
126
- Args:
127
- curr_filepath: Path to the current file version
128
- checkpoint_filepath: Path where the file should be in the checkpoint
129
- file_type: Description of the file type for logging
130
- """
131
- if os.path.exists(checkpoint_filepath):
132
- # Check if existing files are identical
133
- match = check_identical_files(curr_filepath, checkpoint_filepath)
134
-
135
- if not match:
136
- print(
137
- "\n------------------------------------------------------------------------------------------------\n"
138
- f"Found mismatch between:\n"
139
- f"Current: {curr_filepath}\n"
140
- f"Checkpoint: {checkpoint_filepath}\n"
141
- )
142
-
143
- # Create timestamped backup
144
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
145
- backup_path = f"{checkpoint_filepath}.back.{timestamp}"
146
- shutil.copy2(checkpoint_filepath, backup_path)
147
- print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}")
148
-
149
- # Copy current version to checkpoint directory
150
- shutil.copy2(curr_filepath, checkpoint_filepath)
151
- print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}")
152
- print(
153
- f"Changes complete. The checkpoint will now use the current version of {file_type}"
154
- "\n------------------------------------------------------------------------------------------------\n"
155
- )
156
- else:
157
- # If file doesn't exist in checkpoint directory, copy it
158
- shutil.copy2(curr_filepath, checkpoint_filepath)
159
- print(
160
- "\n------------------------------------------------------------------------------------------------\n"
161
- f"No {file_type} found in checkpoint directory.\n"
162
- f"Copied current version from: {curr_filepath}\n"
163
- f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}"
164
- "\n------------------------------------------------------------------------------------------------\n"
165
- )
166
-
167
-
168
- def check_model_logic_mismatch(pretrained_checkpoint: str) -> None:
169
- """
170
- Check and sync model logic files between current code and checkpoint.
171
-
172
- Handles the relationship between current and checkpoint versions of both
173
- modeling_prismatic.py and configuration_prismatic.py:
174
- - If checkpoint file exists and differs: creates backup and copies current version
175
- - If checkpoint file doesn't exist: copies current version
176
-
177
- Args:
178
- pretrained_checkpoint: Path to the checkpoint directory
179
- """
180
- if not os.path.isdir(pretrained_checkpoint):
181
- return
182
-
183
- # Find current files
184
- curr_files = {"modeling_prismatic.py": None, "configuration_prismatic.py": None}
185
-
186
- for root, _, files in os.walk("./prismatic/"):
187
- for filename in curr_files.keys():
188
- if filename in files and curr_files[filename] is None:
189
- curr_files[filename] = os.path.join(root, filename)
190
-
191
- # Check and handle each file
192
- for filename, curr_filepath in curr_files.items():
193
- if curr_filepath is None:
194
- print(f"WARNING: `{filename}` is not found anywhere in the current directory.")
195
- continue
196
-
197
- checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)
198
- _handle_file_sync(curr_filepath, checkpoint_filepath, filename)
199
-
200
-
201
- def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str:
202
- """
203
- Find a specific checkpoint file matching a pattern.
204
-
205
- Args:
206
- pretrained_checkpoint: Path to the checkpoint directory
207
- file_pattern: String pattern to match in filenames
208
-
209
- Returns:
210
- str: Path to the matching checkpoint file
211
-
212
- Raises:
213
- AssertionError: If no files or multiple files match the pattern
214
- """
215
- assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}"
216
-
217
- checkpoint_files = []
218
- for filename in os.listdir(pretrained_checkpoint):
219
- if file_pattern in filename and "checkpoint" in filename:
220
- full_path = os.path.join(pretrained_checkpoint, filename)
221
- checkpoint_files.append(full_path)
222
-
223
- assert len(checkpoint_files) == 1, (
224
- f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}"
225
- )
226
-
227
- return checkpoint_files[0]
228
-
229
-
230
- def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]:
231
- """
232
- Load a component's state dict from checkpoint and handle DDP prefix if present.
233
-
234
- Args:
235
- checkpoint_path: Path to the checkpoint file
236
-
237
- Returns:
238
- Dict: The processed state dictionary for loading
239
- """
240
- state_dict = torch.load(checkpoint_path, weights_only=True)
241
-
242
- # If the component was trained with DDP, elements in the state dict have prefix "module." which we must remove
243
- new_state_dict = {}
244
- for k, v in state_dict.items():
245
- if k.startswith("module."):
246
- new_state_dict[k[7:]] = v
247
- else:
248
- new_state_dict[k] = v
249
-
250
- return new_state_dict
251
-
252
-
253
- def get_vla(cfg: Any) -> torch.nn.Module:
254
- """
255
- Load and initialize the VLA model from checkpoint.
256
-
257
- Args:
258
- cfg: Configuration object
259
-
260
- Returns:
261
- torch.nn.Module: The initialized VLA model
262
- """
263
- print("Instantiating pretrained VLA policy...")
264
-
265
- # If loading a locally stored pretrained checkpoint, check whether config or model files
266
- # need to be synced so that any changes the user makes to the VLA modeling code will
267
- # actually go into effect
268
- # If loading a pretrained checkpoint from Hugging Face Hub, we just assume that the policy
269
- # will be used as is, with its original modeling logic
270
- if not model_is_on_hf_hub(cfg.pretrained_checkpoint):
271
- # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
272
- AutoConfig.register("openvla", OpenVLAConfig)
273
- AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
274
- AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
275
- AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
276
-
277
- # Update config.json and sync model files
278
- update_auto_map(cfg.pretrained_checkpoint)
279
- check_model_logic_mismatch(cfg.pretrained_checkpoint)
280
-
281
- # Load the model
282
- vla = AutoModelForVision2Seq.from_pretrained(
283
- cfg.pretrained_checkpoint,
284
- # attn_implementation="flash_attention_2",
285
- torch_dtype=torch.bfloat16,
286
- load_in_8bit=cfg.load_in_8bit,
287
- load_in_4bit=cfg.load_in_4bit,
288
- low_cpu_mem_usage=True,
289
- trust_remote_code=True,
290
- )
291
-
292
- # If using FiLM, wrap the vision backbone to allow for infusion of language inputs
293
- if cfg.use_film:
294
- vla = _apply_film_to_vla(vla, cfg)
295
-
296
- # Set number of images in model input
297
- vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
298
-
299
- vla.eval()
300
-
301
- # Move model to device if not using quantization
302
- if not cfg.load_in_8bit and not cfg.load_in_4bit:
303
- vla = vla.to(DEVICE)
304
-
305
- # Load dataset stats for action normalization
306
- _load_dataset_stats(vla, cfg.pretrained_checkpoint)
307
-
308
- return vla
309
-
310
-
311
- def _apply_film_to_vla(vla: torch.nn.Module, cfg: Any) -> torch.nn.Module:
312
- """
313
- Apply FiLM (Feature-wise Linear Modulation) to the VLA vision backbone.
314
-
315
- Args:
316
- vla: The VLA model
317
- cfg: Configuration object with model parameters
318
-
319
- Returns:
320
- torch.nn.Module: VLA model with FiLM applied
321
- """
322
- from peft import LoraConfig, get_peft_model
323
-
324
- # Apply LoRA configuration
325
- lora_config = LoraConfig(
326
- r=cfg.lora_rank,
327
- lora_alpha=min(cfg.lora_rank, 16),
328
- lora_dropout=0.0,
329
- target_modules="all-linear",
330
- init_lora_weights="gaussian",
331
- )
332
- vla = get_peft_model(vla, lora_config)
333
-
334
- # Create and apply FiLMed vision backbone
335
- new_vision_backbone = FiLMedPrismaticVisionBackbone(
336
- vision_backbone=vla.vision_backbone, llm_dim=vla.llm_dim,
337
- )
338
- vla.model.vision_backbone = new_vision_backbone
339
-
340
- # Load vision backbone checkpoint
341
- checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "vision_backbone")
342
- state_dict = torch.load(checkpoint_path, weights_only=True)
343
- vla.model.vision_backbone.load_state_dict(state_dict)
344
-
345
- # Use the model component instead of wrapper and convert to bfloat16
346
- vla = vla.model
347
- vla.vision_backbone = vla.vision_backbone.to(torch.bfloat16)
348
-
349
- return vla
350
-
351
-
352
- def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
353
- """
354
- Load dataset statistics used during training for action normalization.
355
-
356
- Args:
357
- vla: The VLA model
358
- checkpoint_path: Path to the checkpoint directory
359
- """
360
- if model_is_on_hf_hub(checkpoint_path):
361
- # Download dataset stats directly from HF Hub
362
- dataset_statistics_path = hf_hub_download(
363
- repo_id=checkpoint_path,
364
- filename="dataset_statistics.json",
365
- )
366
- else:
367
- dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json")
368
- if os.path.isfile(dataset_statistics_path):
369
- with open(dataset_statistics_path, "r") as f:
370
- norm_stats = json.load(f)
371
- vla.norm_stats = norm_stats
372
- else:
373
- print(
374
- "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
375
- "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint."
376
- "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
377
- )
378
-
379
-
380
- def get_processor(cfg: Any) -> AutoProcessor:
381
- """
382
- Get the VLA model's Hugging Face processor.
383
-
384
- Args:
385
- cfg: Configuration object with model parameters
386
-
387
- Returns:
388
- AutoProcessor: The model's processor
389
- """
390
- return AutoProcessor.from_pretrained(cfg.pretrained_checkpoint, trust_remote_code=True)
391
-
392
-
393
- def get_proprio_projector(cfg: Any, llm_dim: int, proprio_dim: int) -> ProprioProjector:
394
- """
395
- Get proprioception projector for the VLA model.
396
-
397
- Args:
398
- cfg: Configuration object with model parameters
399
- llm_dim: Dimension of the language model
400
- proprio_dim: Dimension of proprioception data
401
-
402
- Returns:
403
- ProprioProjector: The initialized proprio projector
404
- """
405
- # Initialize projector and move to device
406
- proprio_projector = ProprioProjector(
407
- llm_dim=llm_dim,
408
- proprio_dim=proprio_dim,
409
- ).to(DEVICE)
410
- proprio_projector = proprio_projector.to(torch.bfloat16).to(DEVICE)
411
- proprio_projector.eval()
412
-
413
- # Find and load checkpoint (may be on Hugging Face Hub or stored locally)
414
- if model_is_on_hf_hub(cfg.pretrained_checkpoint):
415
- model_path_to_proprio_projector_name = {
416
- "moojink/openvla-7b-oft-finetuned-libero-spatial": "proprio_projector--150000_checkpoint.pt",
417
- "moojink/openvla-7b-oft-finetuned-libero-object": "proprio_projector--150000_checkpoint.pt",
418
- "moojink/openvla-7b-oft-finetuned-libero-goal": "proprio_projector--50000_checkpoint.pt",
419
- "moojink/openvla-7b-oft-finetuned-libero-10": "proprio_projector--150000_checkpoint.pt",
420
- "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "proprio_projector--300000_checkpoint.pt",
421
- }
422
- if cfg.pretrained_checkpoint not in model_path_to_proprio_projector_name.keys():
423
- raise ValueError("Unsupported HF Hub pretrained checkpoint found!")
424
- # Download proprio projector directly from HF Hub
425
- proprio_projector_path = hf_hub_download(
426
- repo_id=cfg.pretrained_checkpoint, filename=model_path_to_proprio_projector_name[cfg.pretrained_checkpoint]
427
- )
428
- state_dict = load_component_state_dict(proprio_projector_path)
429
- proprio_projector.load_state_dict(state_dict)
430
- else:
431
- checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "proprio_projector")
432
- state_dict = load_component_state_dict(checkpoint_path)
433
- proprio_projector.load_state_dict(state_dict)
434
-
435
- return proprio_projector
436
-
437
-
438
- def get_noisy_action_projector(cfg: Any, llm_dim: int) -> NoisyActionProjector:
439
- """
440
- Get noisy action projector for diffusion-based action prediction.
441
-
442
- Args:
443
- cfg: Configuration object with model parameters
444
- llm_dim: Dimension of the language model
445
-
446
- Returns:
447
- NoisyActionProjector: The initialized noisy action projector
448
- """
449
- # Initialize projector and move to device
450
- noisy_action_projector = NoisyActionProjector(
451
- llm_dim=llm_dim,
452
- ).to(DEVICE)
453
- noisy_action_projector = noisy_action_projector.to(torch.bfloat16).to(DEVICE)
454
- noisy_action_projector.eval()
455
-
456
- # Find and load checkpoint
457
- checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "noisy_action_projector")
458
- state_dict = load_component_state_dict(checkpoint_path)
459
- noisy_action_projector.load_state_dict(state_dict)
460
-
461
- return noisy_action_projector
462
-
463
-
464
- def get_action_head(cfg: Any, llm_dim: int) -> Union[L1RegressionActionHead, DiffusionActionHead]:
465
- """
466
- Get action head for continuous value prediction.
467
-
468
- Args:
469
- cfg: Configuration object with model parameters
470
- llm_dim: Dimension of the language model
471
-
472
- Returns:
473
- Union[L1RegressionActionHead, DiffusionActionHead]: The initialized action head
474
-
475
- Raises:
476
- AssertionError: If both L1 regression and diffusion are specified
477
- """
478
- assert not (cfg.use_l1_regression and cfg.use_diffusion), "Cannot use both L1 regression and diffusion action head!"
479
-
480
- # Initialize appropriate action head based on configuration
481
- if cfg.use_l1_regression:
482
- action_head = L1RegressionActionHead(input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM)
483
- elif cfg.use_diffusion:
484
- action_head = DiffusionActionHead(
485
- input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM, num_diffusion_steps_train=cfg.num_diffusion_steps_train
486
- )
487
- # Set number of diffusion steps for inference
488
- action_head.noise_scheduler.set_timesteps(cfg.num_diffusion_steps_inference)
489
- else:
490
- raise ValueError("Either use_l1_regression or use_diffusion must be True")
491
-
492
- action_head = action_head.to(torch.bfloat16).to(DEVICE)
493
- action_head.eval()
494
-
495
- # Find and load checkpoint (may be on Hugging Face Hub or stored locally)
496
- if model_is_on_hf_hub(cfg.pretrained_checkpoint):
497
- model_path_to_action_head_name = {
498
- "moojink/openvla-7b-oft-finetuned-libero-spatial": "action_head--150000_checkpoint.pt",
499
- "moojink/openvla-7b-oft-finetuned-libero-object": "action_head--150000_checkpoint.pt",
500
- "moojink/openvla-7b-oft-finetuned-libero-goal": "action_head--50000_checkpoint.pt",
501
- "moojink/openvla-7b-oft-finetuned-libero-10": "action_head--150000_checkpoint.pt",
502
- "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "action_head--300000_checkpoint.pt",
503
- }
504
- if cfg.pretrained_checkpoint not in model_path_to_action_head_name.keys():
505
- raise ValueError("Unsupported HF Hub pretrained checkpoint found!")
506
- # Download action head checkpoint directly from HF Hub
507
- action_head_path = hf_hub_download(
508
- repo_id=cfg.pretrained_checkpoint, filename=model_path_to_action_head_name[cfg.pretrained_checkpoint]
509
- )
510
- state_dict = load_component_state_dict(action_head_path)
511
- action_head.load_state_dict(state_dict)
512
- else:
513
- checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "action_head")
514
- state_dict = load_component_state_dict(checkpoint_path)
515
- action_head.load_state_dict(state_dict)
516
-
517
- return action_head
518
-
519
-
520
- def resize_image_for_policy(img: np.ndarray, resize_size: Union[int, Tuple[int, int]]) -> np.ndarray:
521
- """
522
- Resize an image to match the policy's expected input size.
523
-
524
- Uses the same resizing scheme as in the training data pipeline for distribution matching.
525
-
526
- Args:
527
- img: Numpy array containing the image
528
- resize_size: Target size as int (square) or (height, width) tuple
529
-
530
- Returns:
531
- np.ndarray: The resized image
532
- """
533
- assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
534
- if isinstance(resize_size, int):
535
- resize_size = (resize_size, resize_size)
536
-
537
- # Resize using the same pipeline as in RLDS dataset builder
538
- img = tf.image.encode_jpeg(img) # Encode as JPEG
539
- img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Decode back
540
- img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
541
- img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
542
-
543
- return img.numpy()
544
-
545
-
546
- def crop_and_resize(image: tf.Tensor, crop_scale: float, batch_size: int) -> tf.Tensor:
547
- """
548
- Center-crop an image and resize it back to original dimensions.
549
-
550
- Uses the same logic as in the training data pipeline for distribution matching.
551
-
552
- Args:
553
- image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) with values in [0,1]
554
- crop_scale: Area of center crop relative to original image
555
- batch_size: Batch size
556
-
557
- Returns:
558
- tf.Tensor: The cropped and resized image
559
- """
560
- # Handle 3D inputs by adding batch dimension if needed
561
- assert image.shape.ndims in (3, 4), "Image must be 3D or 4D tensor"
562
- expanded_dims = False
563
- if image.shape.ndims == 3:
564
- image = tf.expand_dims(image, axis=0)
565
- expanded_dims = True
566
-
567
- # Calculate crop dimensions (note: we use sqrt(crop_scale) for h/w)
568
- new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
569
- new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
570
-
571
- # Create bounding box for the crop
572
- height_offsets = (1 - new_heights) / 2
573
- width_offsets = (1 - new_widths) / 2
574
- bounding_boxes = tf.stack(
575
- [
576
- height_offsets,
577
- width_offsets,
578
- height_offsets + new_heights,
579
- width_offsets + new_widths,
580
- ],
581
- axis=1,
582
- )
583
-
584
- # Apply crop and resize
585
- image = tf.image.crop_and_resize(
586
- image, bounding_boxes, tf.range(batch_size), (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE)
587
- )
588
-
589
- # Remove batch dimension if it was added
590
- if expanded_dims:
591
- image = image[0]
592
-
593
- return image
594
-
595
-
596
- def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
597
- """
598
- Center crop an image to match training data distribution.
599
-
600
- Args:
601
- image: Input image (PIL or numpy array)
602
-
603
- Returns:
604
- Image.Image: Cropped PIL Image
605
- """
606
- batch_size = 1
607
- crop_scale = 0.9
608
-
609
- # Convert to TF Tensor if needed
610
- if not isinstance(image, tf.Tensor):
611
- image = tf.convert_to_tensor(np.array(image))
612
-
613
- orig_dtype = image.dtype
614
-
615
- # Convert to float32 in range [0,1]
616
- image = tf.image.convert_image_dtype(image, tf.float32)
617
-
618
- # Apply center crop and resize
619
- image = crop_and_resize(image, crop_scale, batch_size)
620
-
621
- # Convert back to original data type
622
- image = tf.clip_by_value(image, 0, 1)
623
- image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True)
624
-
625
- # Convert to PIL Image
626
- return Image.fromarray(image.numpy()).convert("RGB")
627
-
628
-
629
- def check_image_format(image: Any) -> None:
630
- """
631
- Validate input image format.
632
-
633
- Args:
634
- image: Image to check
635
-
636
- Raises:
637
- AssertionError: If image format is invalid
638
- """
639
- is_numpy_array = isinstance(image, np.ndarray)
640
- has_correct_shape = is_numpy_array and len(image.shape) == 3 and image.shape[-1] == 3
641
- has_correct_dtype = is_numpy_array and image.dtype == np.uint8
642
-
643
- assert is_numpy_array and has_correct_shape and has_correct_dtype, (
644
- "Incorrect image format detected! Make sure that the input image is a "
645
- "numpy array with shape (H, W, 3) and dtype np.uint8!"
646
- )
647
-
648
-
649
- def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray:
650
- """
651
- Normalize proprioception data to match training distribution.
652
-
653
- Args:
654
- proprio: Raw proprioception data
655
- norm_stats: Normalization statistics
656
-
657
- Returns:
658
- np.ndarray: Normalized proprioception data
659
- """
660
- if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
661
- mask = norm_stats.get("mask", np.ones_like(norm_stats["min"], dtype=bool))
662
- proprio_high, proprio_low = np.array(norm_stats["max"]), np.array(norm_stats["min"])
663
- elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
664
- mask = norm_stats.get("mask", np.ones_like(norm_stats["q01"], dtype=bool))
665
- proprio_high, proprio_low = np.array(norm_stats["q99"]), np.array(norm_stats["q01"])
666
- else:
667
- raise ValueError("Unsupported action/proprio normalization type detected!")
668
-
669
- normalized_proprio = np.clip(
670
- np.where(
671
- mask,
672
- 2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1,
673
- proprio,
674
- ),
675
- a_min=-1.0,
676
- a_max=1.0,
677
- )
678
-
679
- return normalized_proprio
680
-
681
-
682
- def prepare_images_for_vla(images: List[np.ndarray], cfg: Any) -> List[Image.Image]:
683
- """
684
- Prepare images for VLA input by resizing and cropping as needed.
685
-
686
- Args:
687
- images: List of input images as numpy arrays
688
- cfg: Configuration object with parameters
689
-
690
- Returns:
691
- List[Image.Image]: Processed images ready for the model
692
- """
693
- processed_images = []
694
-
695
- for image in images:
696
- # Validate format
697
- check_image_format(image)
698
-
699
- # Resize if needed
700
- if image.shape != (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE, 3):
701
- image = resize_image_for_policy(image, OPENVLA_IMAGE_SIZE)
702
-
703
- # Convert to PIL image
704
- pil_image = Image.fromarray(image).convert("RGB")
705
-
706
- # Apply center crop if configured
707
- if cfg.center_crop:
708
- pil_image = center_crop_image(pil_image)
709
-
710
- processed_images.append(pil_image)
711
-
712
- return processed_images
713
-
714
-
715
- def get_vla_action(
716
- cfg: Any,
717
- vla: torch.nn.Module,
718
- processor: Any,
719
- obs: Dict[str, Any],
720
- task_label: str,
721
- action_head: Optional[torch.nn.Module] = None,
722
- proprio_projector: Optional[torch.nn.Module] = None,
723
- noisy_action_projector: Optional[torch.nn.Module] = None,
724
- use_film: bool = False,
725
- ) -> List[np.ndarray]:
726
- """
727
- Generate action predictions with the VLA policy.
728
-
729
- Args:
730
- cfg: Configuration object with parameters
731
- vla: The VLA model
732
- processor: Model processor for inputs
733
- obs: Observation dictionary
734
- task_label: Text description of the task
735
- action_head: Optional action head for continuous actions
736
- proprio_projector: Optional proprioception projector
737
- noisy_action_projector: Optional noisy action projector for diffusion
738
- use_film: Whether to use FiLM
739
-
740
- Returns:
741
- List[np.ndarray]: Predicted actions
742
- """
743
- with torch.inference_mode():
744
-
745
- # Collect all input images
746
- all_images = [obs["full_image"]]
747
- if cfg.num_images_in_input > 1:
748
- all_images.extend([obs[k] for k in obs.keys() if "wrist" in k])
749
-
750
- # Process images
751
- all_images = prepare_images_for_vla(all_images, cfg)
752
-
753
- # Extract primary image and additional images
754
- primary_image = all_images.pop(0)
755
-
756
- # Build VLA prompt
757
- prompt = f"In: What action should the robot take to {task_label.lower()}?\nOut:"
758
-
759
- # Process primary image
760
- inputs = processor(prompt, primary_image).to(DEVICE, dtype=torch.bfloat16)
761
-
762
- # Process additional wrist images if any
763
- if all_images:
764
- all_wrist_inputs = [
765
- processor(prompt, image_wrist).to(DEVICE, dtype=torch.bfloat16) for image_wrist in all_images
766
- ]
767
- # Concatenate all images
768
- primary_pixel_values = inputs["pixel_values"]
769
- all_wrist_pixel_values = [wrist_inputs["pixel_values"] for wrist_inputs in all_wrist_inputs]
770
- inputs["pixel_values"] = torch.cat([primary_pixel_values] + all_wrist_pixel_values, dim=1)
771
-
772
- # Process proprioception data if used
773
- proprio = None
774
- if cfg.use_proprio:
775
- proprio = obs["state"]
776
- proprio_norm_stats = vla.norm_stats[cfg.unnorm_key]["proprio"]
777
- obs["state"] = normalize_proprio(proprio, proprio_norm_stats)
778
- proprio = obs["state"]
779
-
780
- # Generate action
781
- if action_head is None:
782
- # Standard VLA output (single-image inputs, discrete actions)
783
- action, _ = vla.predict_action(**inputs, unnorm_key=cfg.unnorm_key, do_sample=False)
784
- else:
785
- # Custom action head for continuous actions
786
- action, _ = vla.predict_action(
787
- **inputs,
788
- unnorm_key=cfg.unnorm_key,
789
- do_sample=False,
790
- proprio=proprio,
791
- proprio_projector=proprio_projector,
792
- noisy_action_projector=noisy_action_projector,
793
- action_head=action_head,
794
- use_film=use_film,
795
- )
796
-
797
- # Return action chunk as list of actions
798
- return [action[i] for i in range(len(action))]
799
-
800
-
801
- def get_action_from_server(
802
- observation: Dict[str, Any], server_endpoint: str = "http://0.0.0.0:8777/act"
803
- ) -> Dict[str, Any]:
804
- """
805
- Get VLA action from remote inference server.
806
-
807
- Args:
808
- observation: Observation data to send to server
809
- server_endpoint: URL of the inference server
810
-
811
- Returns:
812
- Dict[str, Any]: Action response from server
813
- """
814
- response = requests.post(
815
- server_endpoint,
816
- json=observation,
817
- )
818
- return response.json()
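
For reference, the helpers in the removed openvla_utils.py composed into a single evaluation step roughly as sketched below. This is a minimal, hedged sketch rather than repository code: the cfg fields are assumed from the signatures above, proprio_dim=8 and the unnorm_key value are illustrative, and observation construction is environment-specific.

from types import SimpleNamespace
import numpy as np

# Hypothetical config; every field is one referenced by the helpers above.
cfg = SimpleNamespace(
    pretrained_checkpoint="moojink/openvla-7b-oft-finetuned-libero-spatial",
    load_in_8bit=False, load_in_4bit=False,
    use_film=False, num_images_in_input=1, center_crop=True,
    use_l1_regression=True, use_diffusion=False,
    use_proprio=True, unnorm_key="libero_spatial",  # unnorm_key value is an assumption
)

vla = get_vla(cfg)              # load policy + dataset stats
processor = get_processor(cfg)  # matching Hugging Face processor

# Optional components, depending on the fine-tuning recipe
action_head = get_action_head(cfg, llm_dim=vla.llm_dim)
proprio_projector = get_proprio_projector(cfg, llm_dim=vla.llm_dim, proprio_dim=8)

# One control step: observation dict -> chunk of continuous actions
obs = {"full_image": np.zeros((256, 256, 3), dtype=np.uint8), "state": np.zeros(8)}
actions = get_vla_action(
    cfg, vla, processor, obs, task_label="pick up the cup",
    action_head=action_head, proprio_projector=proprio_projector,
)
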
capvector-oft/experiments/robot/robot_utils.py DELETED
@@ -1,199 +0,0 @@
1
- """Utils for evaluating robot policies in various environments."""
2
-
3
- import os
4
- import random
5
- import time
6
- from typing import Any, Dict, List, Optional, Union
7
-
8
- import numpy as np
9
- import torch
10
-
11
- from experiments.robot.openvla_utils import (
12
- get_vla,
13
- get_vla_action,
14
- )
15
-
16
- # Initialize important constants
17
- ACTION_DIM = 7
18
- DATE = time.strftime("%Y_%m_%d")
19
- DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
20
- DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
21
-
22
- # Configure NumPy print settings
23
- np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
24
-
25
- # Initialize system prompt for OpenVLA v0.1
26
- OPENVLA_V01_SYSTEM_PROMPT = (
27
- "A chat between a curious user and an artificial intelligence assistant. "
28
- "The assistant gives helpful, detailed, and polite answers to the user's questions."
29
- )
30
-
31
- # Model image size configuration
32
- MODEL_IMAGE_SIZES = {
33
- "openvla": 224,
34
- # Add other models as needed
35
- }
36
-
37
-
38
- def set_seed_everywhere(seed: int) -> None:
39
- """
40
- Set random seed for all random number generators for reproducibility.
41
-
42
- Args:
43
- seed: The random seed to use
44
- """
45
- torch.manual_seed(seed)
46
- torch.cuda.manual_seed_all(seed)
47
- np.random.seed(seed)
48
- random.seed(seed)
49
- torch.backends.cudnn.deterministic = True
50
- torch.backends.cudnn.benchmark = False
51
- os.environ["PYTHONHASHSEED"] = str(seed)
52
-
53
-
54
- def get_model(cfg: Any, wrap_diffusion_policy_for_droid: bool = False) -> torch.nn.Module:
55
- """
56
- Load and initialize model for evaluation based on configuration.
57
-
58
- Args:
59
- cfg: Configuration object with model parameters
60
- wrap_diffusion_policy_for_droid: Whether to wrap diffusion policy for DROID
61
-
62
- Returns:
63
- torch.nn.Module: The loaded model
64
-
65
- Raises:
66
- ValueError: If model family is not supported
67
- """
68
- if cfg.model_family == "openvla":
69
- model = get_vla(cfg)
70
- else:
71
- raise ValueError(f"Unsupported model family: {cfg.model_family}")
72
-
73
- print(f"Loaded model: {type(model)}")
74
- return model
75
-
76
-
77
- def get_image_resize_size(cfg: Any) -> Union[int, tuple]:
78
- """
79
- Get image resize dimensions for a specific model.
80
-
81
- If returned value is an int, the resized image will be a square.
82
- If returned value is a tuple, the resized image will be a rectangle.
83
-
84
- Args:
85
- cfg: Configuration object with model parameters
86
-
87
- Returns:
88
- Union[int, tuple]: Image resize dimensions
89
-
90
- Raises:
91
- ValueError: If model family is not supported
92
- """
93
- if cfg.model_family not in MODEL_IMAGE_SIZES:
94
- raise ValueError(f"Unsupported model family: {cfg.model_family}")
95
-
96
- return MODEL_IMAGE_SIZES[cfg.model_family]
97
-
98
-
99
- def get_action(
100
- cfg: Any,
101
- model: torch.nn.Module,
102
- obs: Dict[str, Any],
103
- task_label: str,
104
- processor: Optional[Any] = None,
105
- action_head: Optional[torch.nn.Module] = None,
106
- proprio_projector: Optional[torch.nn.Module] = None,
107
- noisy_action_projector: Optional[torch.nn.Module] = None,
108
- use_film: bool = False,
109
- ) -> Union[List[np.ndarray], np.ndarray]:
110
- """
111
- Query the model to get action predictions.
112
-
113
- Args:
114
- cfg: Configuration object with model parameters
115
- model: The loaded model
116
- obs: Observation dictionary
117
- task_label: Text description of the task
118
- processor: Model processor for inputs
119
- action_head: Optional action head for continuous actions
120
- proprio_projector: Optional proprioception projector
121
- noisy_action_projector: Optional noisy action projector for diffusion
122
- use_film: Whether to use FiLM
123
-
124
- Returns:
125
- Union[List[np.ndarray], np.ndarray]: Predicted actions
126
-
127
- Raises:
128
- ValueError: If model family is not supported
129
- """
130
- with torch.no_grad():
131
- if cfg.model_family == "openvla":
132
- action = get_vla_action(
133
- cfg=cfg,
134
- vla=model,
135
- processor=processor,
136
- obs=obs,
137
- task_label=task_label,
138
- action_head=action_head,
139
- proprio_projector=proprio_projector,
140
- noisy_action_projector=noisy_action_projector,
141
- use_film=use_film,
142
- )
143
- else:
144
- raise ValueError(f"Unsupported model family: {cfg.model_family}")
145
-
146
- return action
147
-
148
-
149
- def normalize_gripper_action(action: np.ndarray, binarize: bool = True) -> np.ndarray:
150
- """
151
- Normalize gripper action from [0,1] to [-1,+1] range.
152
-
153
- This is necessary for some environments because the dataset wrapper
154
- standardizes gripper actions to [0,1]. Note that unlike the other action
155
- dimensions, the gripper action is not normalized to [-1,+1] by default.
156
-
157
- Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1
158
-
159
- Args:
160
- action: Action array with gripper action in the last dimension
161
- binarize: Whether to binarize gripper action to -1 or +1
162
-
163
- Returns:
164
- np.ndarray: Action array with normalized gripper action
165
- """
166
- # Create a copy to avoid modifying the original
167
- normalized_action = action.copy()
168
-
169
- # Normalize the last action dimension to [-1,+1]
170
- orig_low, orig_high = 0.0, 1.0
171
- normalized_action[..., -1] = 2 * (normalized_action[..., -1] - orig_low) / (orig_high - orig_low) - 1
172
-
173
- if binarize:
174
- # Binarize to -1 or +1
175
- normalized_action[..., -1] = np.sign(normalized_action[..., -1])
176
-
177
- return normalized_action
178
-
179
-
180
- def invert_gripper_action(action: np.ndarray) -> np.ndarray:
181
- """
182
- Flip the sign of the gripper action (last dimension of action vector).
183
-
184
- This is necessary for environments where -1 = open, +1 = close, since
185
- the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.
186
-
187
- Args:
188
- action: Action array with gripper action in the last dimension
189
-
190
- Returns:
191
- np.ndarray: Action array with inverted gripper action
192
- """
193
- # Create a copy to avoid modifying the original
194
- inverted_action = action.copy()
195
-
196
- # Invert the gripper action
197
- inverted_action[..., -1] *= -1.0
198
-
199
- return inverted_action
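
To make the two gripper helpers above concrete (a small sketch, not repository code): normalize_gripper_action applies y = 2x - 1 to the last action dimension, so a dataset gripper value of 0.8 maps to 0.6 and binarization snaps it to +1; invert_gripper_action then flips the sign for environments that use the opposite open/close convention.

import numpy as np

# 7-DoF action: 6 pose deltas + gripper in [0, 1] (RLDS convention: 0 = close, 1 = open)
action = np.array([0.01, -0.02, 0.0, 0.0, 0.0, 0.0, 0.8])

a = normalize_gripper_action(action, binarize=True)
print(a[-1])  # 2 * 0.8 - 1 = 0.6, binarized to +1.0

a = invert_gripper_action(a)
print(a[-1])  # -1.0 (open, for environments where -1 = open and +1 = close)
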
capvector-oft/prismatic/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .models import available_model_names, available_models, get_model_description, load
 
 
capvector-oft/prismatic/conf/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .datasets import DatasetConfig, DatasetRegistry
2
- from .models import ModelConfig, ModelRegistry
3
- from .vla import VLAConfig, VLARegistry
 
 
capvector-oft/prismatic/conf/datasets.py DELETED
@@ -1,133 +0,0 @@
1
- """
2
- datasets.py
3
-
4
- Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant
5
- and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes:
6
- - Dataset Variant (Identifier) --> e.g., "llava-v15"
7
- - Align Stage Dataset Components (annotations, images)
8
- - Finetune Stage Dataset Components (annotations, images)
9
- - Dataset Root Directory (Path)
10
- """
11
-
12
- from dataclasses import dataclass
13
- from enum import Enum, unique
14
- from pathlib import Path
15
- from typing import Tuple
16
-
17
- from draccus import ChoiceRegistry
18
-
19
-
20
- @dataclass
21
- class DatasetConfig(ChoiceRegistry):
22
- # fmt: off
23
- dataset_id: str # Unique ID that fully specifies a dataset variant
24
-
25
- # Dataset Components for each Stage in < align | finetune >
26
- align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage
27
- finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage
28
-
29
- dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root
30
- # fmt: on
31
-
32
-
33
- # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
34
- @dataclass
35
- class LLaVa_V15_Config(DatasetConfig):
36
- dataset_id: str = "llava-v15"
37
-
38
- align_stage_components: Tuple[Path, Path] = (
39
- Path("download/llava-laion-cc-sbu-558k/chat.json"),
40
- Path("download/llava-laion-cc-sbu-558k/"),
41
- )
42
- finetune_stage_components: Tuple[Path, Path] = (
43
- Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"),
44
- Path("download/llava-v1.5-instruct/"),
45
- )
46
- dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
47
-
48
-
49
- # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training)
50
- @dataclass
51
- class LLaVa_Multimodal_Only_Config(DatasetConfig):
52
- dataset_id: str = "llava-multimodal"
53
-
54
- align_stage_components: Tuple[Path, Path] = (
55
- Path("download/llava-laion-cc-sbu-558k/chat.json"),
56
- Path("download/llava-laion-cc-sbu-558k/"),
57
- )
58
- finetune_stage_components: Tuple[Path, Path] = (
59
- Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
60
- Path("download/llava-v1.5-instruct/"),
61
- )
62
- dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
63
-
64
-
65
- # LLaVa-v15 + LVIS-Instruct-4V
66
- @dataclass
67
- class LLaVa_LVIS4V_Config(DatasetConfig):
68
- dataset_id: str = "llava-lvis4v"
69
-
70
- align_stage_components: Tuple[Path, Path] = (
71
- Path("download/llava-laion-cc-sbu-558k/chat.json"),
72
- Path("download/llava-laion-cc-sbu-558k/"),
73
- )
74
- finetune_stage_components: Tuple[Path, Path] = (
75
- Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
76
- Path("download/llava-v1.5-instruct/"),
77
- )
78
- dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
79
-
80
-
81
- # LLaVa-v15 + LRV-Instruct
82
- @dataclass
83
- class LLaVa_LRV_Config(DatasetConfig):
84
- dataset_id: str = "llava-lrv"
85
-
86
- align_stage_components: Tuple[Path, Path] = (
87
- Path("download/llava-laion-cc-sbu-558k/chat.json"),
88
- Path("download/llava-laion-cc-sbu-558k/"),
89
- )
90
- finetune_stage_components: Tuple[Path, Path] = (
91
- Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
92
- Path("download/llava-v1.5-instruct/"),
93
- )
94
- dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
95
-
96
-
97
- # LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
98
- @dataclass
99
- class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
100
- dataset_id: str = "llava-lvis4v-lrv"
101
-
102
- align_stage_components: Tuple[Path, Path] = (
103
- Path("download/llava-laion-cc-sbu-558k/chat.json"),
104
- Path("download/llava-laion-cc-sbu-558k/"),
105
- )
106
- finetune_stage_components: Tuple[Path, Path] = (
107
- Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
108
- Path("download/llava-v1.5-instruct/"),
109
- )
110
- dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
111
-
112
-
113
- # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! ===
114
- @unique
115
- class DatasetRegistry(Enum):
116
- # === LLaVa v1.5 ===
117
- LLAVA_V15 = LLaVa_V15_Config
118
-
119
- LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config
120
-
121
- LLAVA_LVIS4V = LLaVa_LVIS4V_Config
122
- LLAVA_LRV = LLaVa_LRV_Config
123
-
124
- LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config
125
-
126
- @property
127
- def dataset_id(self) -> str:
128
- return self.value.dataset_id
129
-
130
-
131
- # Register Datasets in Choice Registry
132
- for dataset_variant in DatasetRegistry:
133
- DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value)
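
To illustrate how the registry above was meant to be consumed (a minimal sketch; only names defined in this file are used, and draccus is assumed to be installed):

# Each DatasetRegistry member wraps a dataclass; instantiating it yields the defaults.
cfg = DatasetRegistry.LLAVA_V15.value()
print(cfg.dataset_id)                    # "llava-v15"
print(cfg.finetune_stage_components[0])  # download/llava-v1.5-instruct/llava_v1_5_mix665k.json

# The register_subclass loop at the bottom is what lets draccus resolve a config
# from its string ID (e.g., `--dataset.type "llava-lvis4v"` on the command line).
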
capvector-oft/prismatic/conf/models.py DELETED
@@ -1,584 +0,0 @@
1
- """
2
- models.py
3
-
4
- Draccus Dataclass Definition for a ModelConfig object, with various registered subclasses for each model family and
5
- variant thereof. A given model variant configures the following attributes:
6
- - Pretrained Visual Representation (e.g., OpenAI CLIP ViT-L/14) + Pretrained LLM Backbone (e.g., LLaMa-2 7B)
7
- - VLM Configuration + Parameters (e.g., MLP Projector, Image Preprocessing, etc.)
8
- - [Optional] Stage 1 (`align`) Optimization Hyperparameters
9
- - Stage 2 (`finetune`) Optimization Hyperparameters
10
- """
11
-
12
- from dataclasses import dataclass
13
- from enum import Enum, unique
14
- from typing import Optional
15
-
16
- from draccus import ChoiceRegistry
17
-
18
-
19
- @dataclass
20
- class ModelConfig(ChoiceRegistry):
21
- # fmt: off
22
- model_id: str # Unique Model ID that fully specifies a given variant
23
- arch_specifier: str # Architecture specifier string (e.g., "gelu-mlp")
24
-
25
- # Pretrained Backbones
26
- vision_backbone_id: str # Pretrained Visual Featurizer (from TIMM) to load
27
- llm_backbone_id: str # Pretrained LLM (from HF Transformers) to load
28
-
29
- # Backbone Parameters
30
- image_resize_strategy: str # Resizing strategy in < crop | letterbox | corner-pad >
31
- llm_max_length: int # Maximum context length for LLM (can be < than max!)
32
-
33
- # === Multi-Stage Optimization Hyperparameters ===
34
- # By default, we assume an AdamW optimizer with FSDP (Gradient Sharding or Full Sharding depending on stage)
35
-
36
- # Align Stage Optimization Parameters
37
- align_epochs: int # Epochs to Run (in case `max_steps` is not specified)
38
- align_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs)
39
- align_global_batch_size: int # Global Batch Size (divided across processes)
40
- align_per_device_batch_size: int # Per-Device Batch Size (per-process)
41
- # => # of accumulation steps is auto-computed
42
-
43
- align_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay)
44
- align_weight_decay: float # Weight Decay for AdamW Optimizer
45
- align_max_grad_norm: float # Max Grad Norm (for global gradient clipping)
46
- align_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay")
47
- align_warmup_ratio: float # Fraction of total steps to warmup
48
-
49
- align_train_strategy: str # Align Train Strategy (default: "fsdp-shard-grad-op")
50
-
51
- # Finetune Stage Optimization Parameters
52
- finetune_epochs: int # Epochs to Run (in case `max_steps` is not specified)
53
- finetune_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs)
54
- finetune_global_batch_size: int # Global Batch Size (divided across processes)
55
- finetune_per_device_batch_size: int # Per-Device Batch Size (per-process)
56
- # => # of accumulation steps is auto-computed
57
-
58
- finetune_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay)
59
- finetune_weight_decay: float # Weight Decay for AdamW Optimizer
60
- finetune_max_grad_norm: float # Max Grad Norm (for global gradient clipping)
61
- finetune_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay")
62
- finetune_warmup_ratio: float # Fraction of total steps to warmup
63
-
64
- finetune_train_strategy: str # Finetune Train Strategy (default: "fsdp-full-shard")
65
-
66
- # Enable Gradient/Activation Checkpointing (for the LLM Backbone)
67
- enable_gradient_checkpointing: bool = True
68
-
69
- # Enable Traditional Mixed Precision Training via Torch Native AMP (`autocast`)
70
- enable_mixed_precision_training: bool = True # Whether to enable mixed precision training
71
- reduce_in_full_precision: bool = False # Whether to run gradient reduction in FP32
72
-
73
- # fmt: on
74
-
75
-
76
- # === LLaVa v1.5 Reproduction - Fully Specified Configurations ===
77
- @dataclass
78
- class LLaVa_v15_Reproduction_7B(ModelConfig):
79
- model_id: str = "reproduction-llava-v15+7b"
80
- arch_specifier: str = "gelu-mlp"
81
-
82
- vision_backbone_id: str = "clip-vit-l-336px"
83
- llm_backbone_id: str = "vicuna-v15-7b"
84
-
85
- image_resize_strategy: str = "letterbox"
86
- llm_max_length: int = 2048
87
-
88
- # Align Stage Optimization Parameters
89
- align_epochs: int = 1
90
- align_max_steps: Optional[int] = None
91
- align_global_batch_size: int = 256
92
- align_per_device_batch_size: int = 16
93
-
94
- align_learning_rate: float = 1e-3
95
- align_weight_decay: float = 0.0
96
- align_max_grad_norm: float = 1.0
97
- align_lr_scheduler_type: str = "linear-warmup+cosine-decay"
98
- align_warmup_ratio: float = 0.03
99
-
100
- align_train_strategy: str = "fsdp-shard-grad-op"
101
-
102
- # Finetune Stage Optimization Parameters
103
- finetune_epochs: int = 1
104
- finetune_max_steps: Optional[int] = None
105
- finetune_global_batch_size: int = 128
106
- finetune_per_device_batch_size: int = 16
107
-
108
- finetune_learning_rate: float = 2e-5
109
- finetune_weight_decay: float = 0.1
110
- finetune_max_grad_norm: float = 1.0
111
- finetune_lr_scheduler_type: str = "linear-warmup+cosine-decay"
112
- finetune_warmup_ratio: float = 0.03
113
-
114
- finetune_train_strategy: str = "fsdp-full-shard"
115
-
116
-
117
- @dataclass
118
- class LLaVa_v15_Reproduction_13B(LLaVa_v15_Reproduction_7B):
119
- model_id: str = "reproduction-llava-v15+13b"
120
- llm_backbone_id: str = "vicuna-v15-13b"
121
-
122
-
123
- # === Section 4.1 :: Optimization Procedure ===
124
-
125
-
126
- # Section 4.1A :: 🚀 --> Necessity of Multi-Stage Training
127
- @dataclass
128
- class Exp_7B_One_Stage(LLaVa_v15_Reproduction_7B):
129
- model_id: str = "one-stage+7b"
130
- arch_specifier: str = "no-align+gelu-mlp"
131
-
132
-
133
- @dataclass
134
- class Exp_13B_One_Stage(LLaVa_v15_Reproduction_13B):
135
- model_id: str = "one-stage+13b"
136
- arch_specifier: str = "no-align+gelu-mlp"
137
-
138
-
139
- # Section 4.1B :: 🛠️ --> Full Finetuning through Visual Backbones
140
- # =>> Note :: Run with `--stage full-finetune`
141
- @dataclass
142
- class Exp_7B_Full_Finetune_Multi_Stage(LLaVa_v15_Reproduction_7B):
143
- model_id: str = "full-ft-multi-stage+7b"
144
-
145
-
146
- @dataclass
147
- class Exp_7B_Full_Finetune_One_Stage(Exp_7B_One_Stage):
148
- model_id: str = "full-ft-one-stage+7b"
149
-
150
-
151
- # === Section 4.2 :: Image Processing and Visual Representations ===
152
-
153
-
154
- # Section 4.2A :: 📸 --> Choosing a Pretrained Representation
155
- @dataclass
156
- class Exp_7B_IN1K_ViT_L_p16_224px(Exp_7B_One_Stage):
157
- model_id: str = "in1k-224px+7b"
158
- vision_backbone_id: str = "in1k-vit-l"
159
-
160
-
161
- @dataclass
162
- class Exp_7B_DINOv2_ViT_L_p14_224px(Exp_7B_One_Stage):
163
- model_id: str = "dinov2-224px+7b"
164
- vision_backbone_id: str = "dinov2-vit-l"
165
-
166
-
167
- @dataclass
168
- class Exp_7B_CLIP_ViT_L_p14_224px(Exp_7B_One_Stage):
169
- model_id: str = "clip-224px+7b"
170
- vision_backbone_id: str = "clip-vit-l"
171
-
172
-
173
- @dataclass
174
- class Exp_7B_SigLIP_ViT_SO_p14_224px(Exp_7B_One_Stage):
175
- model_id: str = "siglip-224px+7b"
176
- vision_backbone_id: str = "siglip-vit-so400m"
177
-
178
-
179
- # Section 4.2B :: 📐 --> Choosing an Image Preprocessing Strategy
180
- @dataclass
181
- class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop(Exp_7B_One_Stage):
182
- model_id: str = "clip-336px-resize-crop+7b"
183
- image_resize_strategy: str = "resize-crop"
184
-
185
-
186
- @dataclass
187
- class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage):
188
- model_id: str = "clip-336px-resize-naive+7b"
189
- image_resize_strategy: str = "resize-naive"
190
-
191
-
192
- @dataclass
193
- class Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox(Exp_7B_One_Stage):
194
- model_id: str = "siglip-384px-letterbox+7b"
195
- vision_backbone_id: str = "siglip-vit-so400m-384px"
196
- image_resize_strategy: str = "letterbox"
197
-
198
-
199
- @dataclass
200
- class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop(Exp_7B_One_Stage):
201
- model_id: str = "siglip-384px-resize-crop+7b"
202
- vision_backbone_id: str = "siglip-vit-so400m-384px"
203
- image_resize_strategy: str = "resize-crop"
204
-
205
-
206
- @dataclass
207
- class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive(Exp_7B_One_Stage):
208
- model_id: str = "siglip-384px-resize-naive+7b"
209
- vision_backbone_id: str = "siglip-vit-so400m-384px"
210
- image_resize_strategy: str = "resize-naive"
211
-
212
-
213
- # Section 4.2D :: 🥞 --> Stacking/Ensembling Visual Representations
214
- @dataclass
215
- class Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox(Exp_7B_One_Stage):
216
- model_id: str = "dinoclip-336px-letterbox+7b"
217
- vision_backbone_id: str = "dinoclip-vit-l-336px"
218
- image_resize_strategy: str = "letterbox"
219
- arch_specifier: str = "no-align+fused-gelu-mlp"
220
-
221
-
222
- @dataclass
223
- class Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage):
224
- model_id: str = "dinoclip-336px-resize-naive+7b"
225
- vision_backbone_id: str = "dinoclip-vit-l-336px"
226
- image_resize_strategy: str = "resize-naive"
227
- arch_specifier: str = "no-align+fused-gelu-mlp"
228
-
229
-
230
- @dataclass
231
- class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox(Exp_7B_One_Stage):
232
- model_id: str = "dinosiglip-384px-letterbox+7b"
233
- vision_backbone_id: str = "dinosiglip-vit-so-384px"
234
- image_resize_strategy: str = "letterbox"
235
- arch_specifier: str = "no-align+fused-gelu-mlp"
236
-
237
-
238
- @dataclass
239
- class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive(Exp_7B_One_Stage):
240
- model_id: str = "dinosiglip-384px-resize-naive+7b"
241
- vision_backbone_id: str = "dinosiglip-vit-so-384px"
242
- image_resize_strategy: str = "resize-naive"
243
- arch_specifier: str = "no-align+fused-gelu-mlp"
244
-
245
-
246
- # === Section 4.3 :: Language Models ===
247
-
248
-
249
- # Section 4.3A :: 📝 --> Base vs. Instruct-Tuned (Chat) LLMs
250
- @dataclass
251
- class Exp_7B_Llama2(Exp_7B_One_Stage):
252
- model_id: str = "llama2+7b"
253
- llm_backbone_id: str = "llama2-7b-pure"
254
-
255
-
256
- @dataclass
257
- class Exp_13B_Llama2(Exp_13B_One_Stage):
258
- model_id: str = "llama2+13b"
259
- llm_backbone_id: str = "llama2-13b-pure"
260
-
261
-
262
- # ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~
263
- @dataclass
264
- class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
265
- model_id: str = "llama2-chat+7b"
266
- llm_backbone_id: str = "llama2-7b-chat"
267
-
268
-
269
- @dataclass
270
- class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage):
271
- model_id: str = "llama2-chat+13b"
272
- llm_backbone_id: str = "llama2-13b-chat"
273
-
274
-
275
- @dataclass
276
- class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage):
277
- model_id: str = "mistral-v0.1+7b"
278
- llm_backbone_id: str = "mistral-v0.1-7b-pure"
279
-
280
-
281
- @dataclass
282
- class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
283
- model_id: str = "mistral-instruct-v0.1+7b"
284
- llm_backbone_id: str = "mistral-v0.1-7b-instruct"
285
-
286
-
287
- @dataclass
288
- class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage):
289
- model_id: str = "phi-2+3b"
290
- llm_backbone_id: str = "phi-2-3b"
291
-
292
-
293
- # Section 4.3B :: ✌️ --> Co-training on Language-only Data
294
- # =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training)
295
- @dataclass
296
- class Exp_7B_Vicuna_No_Cotraining(Exp_7B_One_Stage):
297
- model_id: str = "vicuna-no-cotraining+7b"
298
-
299
-
300
- @dataclass
301
- class Exp_7B_Llama2_No_Cotraining(Exp_7B_One_Stage):
302
- model_id: str = "llama2-no-cotraining+7b"
303
- llm_backbone_id: str = "llama2-7b-pure"
304
-
305
-
306
- # === Section 4.4 :: Scaling Properties - Train Time & Data ===
307
-
308
-
309
- # Section 4.4A :: ⏰ --> Scaling Train Time
310
- @dataclass
311
- class Exp_7B_1p25_Epochs(Exp_7B_One_Stage):
312
- model_id: str = "train-1.25-epochs+7b"
313
- finetune_max_steps: int = 6500
314
-
315
-
316
- @dataclass
317
- class Exp_7B_1p5_Epochs(Exp_7B_One_Stage):
318
- model_id: str = "train-1.5-epochs+7b"
319
- finetune_max_steps: int = 7800
320
-
321
-
322
- @dataclass
323
- class Exp_7B_2_Epochs(Exp_7B_One_Stage):
324
- model_id: str = "train-2-epochs+7b"
325
- finetune_epochs: int = 2
326
-
327
-
328
- @dataclass
329
- class Exp_7B_3_Epochs(Exp_7B_One_Stage):
330
- model_id: str = "train-3-epochs+7b"
331
- finetune_epochs: int = 3
332
-
333
-
334
- # Section 4.4B :: 📚 --> Scaling Data
335
- # =>> Note :: Run with `--dataset.type "llava-lvis4v"`
336
- @dataclass
337
- class Exp_7B_LLaVa_LVIS4V(Exp_7B_One_Stage):
338
- model_id: str = "llava-lvis4v+7b"
339
-
340
-
341
- # =>> Note :: Run with `--dataset.type "llava-lrv"`
342
- @dataclass
343
- class Exp_7B_LLaVa_LRV(Exp_7B_One_Stage):
344
- model_id: str = "llava-lrv+7b"
345
-
346
-
347
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
348
- @dataclass
349
- class Exp_7B_LLaVa_LVIS4V_LRV(Exp_7B_One_Stage):
350
- model_id: str = "llava-lvis4v-lrv+7b"
351
-
352
-
353
- # === Section 5 :: Prisms ===
354
-
355
-
356
- # Prism-CLIP
357
- @dataclass
358
- class Prism_7B_CLIP_Controlled(Exp_7B_One_Stage):
359
- model_id: str = "prism-clip-controlled+7b"
360
- vision_backbone_id: str = "clip-vit-l-336px"
361
- image_resize_strategy: str = "resize-naive"
362
- llm_backbone_id: str = "llama2-7b-pure"
363
-
364
-
365
- @dataclass
366
- class Prism_13B_CLIP_Controlled(Exp_13B_One_Stage):
367
- model_id: str = "prism-clip-controlled+13b"
368
- vision_backbone_id: str = "clip-vit-l-336px"
369
- image_resize_strategy: str = "resize-naive"
370
- llm_backbone_id: str = "llama2-13b-pure"
371
-
372
-
373
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
374
- @dataclass
375
- class Prism_7B_CLIP(Exp_7B_One_Stage):
376
- model_id: str = "prism-clip+7b"
377
- vision_backbone_id: str = "clip-vit-l-336px"
378
- image_resize_strategy: str = "resize-naive"
379
- llm_backbone_id: str = "llama2-7b-pure"
380
- finetune_epochs: int = 2
381
-
382
-
383
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
384
- @dataclass
385
- class Prism_13B_CLIP(Exp_13B_One_Stage):
386
- model_id: str = "prism-clip+13b"
387
- vision_backbone_id: str = "clip-vit-l-336px"
388
- image_resize_strategy: str = "resize-naive"
389
- llm_backbone_id: str = "llama2-13b-pure"
390
- finetune_epochs: int = 2
391
-
392
-
393
- # Prism-SigLIP
394
- @dataclass
395
- class Prism_7B_SigLIP_Controlled(Exp_7B_One_Stage):
396
- model_id: str = "prism-siglip-controlled+7b"
397
- vision_backbone_id: str = "siglip-vit-so400m-384px"
398
- image_resize_strategy: str = "resize-naive"
399
- llm_backbone_id: str = "llama2-7b-pure"
400
-
401
-
402
- @dataclass
403
- class Prism_13B_SigLIP_Controlled(Exp_13B_One_Stage):
404
- model_id: str = "prism-siglip-controlled+13b"
405
- vision_backbone_id: str = "siglip-vit-so400m-384px"
406
- image_resize_strategy: str = "resize-naive"
407
- llm_backbone_id: str = "llama2-13b-pure"
408
-
409
-
410
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
411
- @dataclass
412
- class Prism_7B_SigLIP(Exp_7B_One_Stage):
413
- model_id: str = "prism-siglip+7b"
414
- vision_backbone_id: str = "siglip-vit-so400m-384px"
415
- image_resize_strategy: str = "resize-naive"
416
- llm_backbone_id: str = "llama2-7b-pure"
417
- finetune_epochs: int = 2
418
-
419
-
420
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
421
- @dataclass
422
- class Prism_13B_SigLIP(Exp_13B_One_Stage):
423
- model_id: str = "prism-siglip+13b"
424
- vision_backbone_id: str = "clip-vit-l-336px"
425
- image_resize_strategy: str = "resize-naive"
426
- llm_backbone_id: str = "llama2-13b-pure"
427
- finetune_epochs: int = 2
428
-
429
-
430
- # Prism-DINOSigLIP
431
- @dataclass
432
- class Prism_7B_DINOSigLIP_Controlled(Exp_7B_One_Stage):
433
- model_id: str = "prism-dinosiglip-controlled+7b"
434
- vision_backbone_id: str = "dinosiglip-vit-so-384px"
435
- image_resize_strategy: str = "resize-naive"
436
- llm_backbone_id: str = "llama2-7b-pure"
437
- arch_specifier: str = "no-align+fused-gelu-mlp"
438
-
439
-
440
- @dataclass
441
- class Prism_13B_DINOSigLIP_Controlled(Exp_13B_One_Stage):
442
- model_id: str = "prism-dinosiglip-controlled+13b"
443
- vision_backbone_id: str = "dinosiglip-vit-so-384px"
444
- image_resize_strategy: str = "resize-naive"
445
- llm_backbone_id: str = "llama2-13b-pure"
446
- arch_specifier: str = "no-align+fused-gelu-mlp"
447
-
448
-
449
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
450
- @dataclass
451
- class Prism_7B_DINOSigLIP(Exp_7B_One_Stage):
452
- model_id: str = "prism-dinosiglip+7b"
453
- vision_backbone_id: str = "dinosiglip-vit-so-384px"
454
- image_resize_strategy: str = "resize-naive"
455
- llm_backbone_id: str = "llama2-7b-pure"
456
- arch_specifier: str = "no-align+fused-gelu-mlp"
457
- finetune_epochs: int = 2
458
-
459
-
460
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
461
- @dataclass
462
- class Prism_13B_DINOSigLIP(Exp_13B_One_Stage):
463
- model_id: str = "prism-dinosiglip+13b"
464
- vision_backbone_id: str = "dinosiglip-vit-so-384px"
465
- image_resize_strategy: str = "resize-naive"
466
- llm_backbone_id: str = "llama2-13b-pure"
467
- arch_specifier: str = "no-align+fused-gelu-mlp"
468
- finetune_epochs: int = 2
469
-
470
-
471
- # [Inference-Optimized] 224px Prisms
472
- @dataclass
473
- class Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive(Exp_7B_One_Stage):
474
- model_id: str = "dinosiglip-224px-resize-naive+7b"
475
- vision_backbone_id: str = "dinosiglip-vit-so-224px"
476
- image_resize_strategy: str = "resize-naive"
477
- arch_specifier: str = "no-align+fused-gelu-mlp"
478
-
479
-
480
- @dataclass
481
- class Prism_7B_DINOSigLIP_224px_Controlled(Exp_7B_One_Stage):
482
- model_id: str = "prism-dinosiglip-224px-controlled+7b"
483
- vision_backbone_id: str = "dinosiglip-vit-so-224px"
484
- image_resize_strategy: str = "resize-naive"
485
- llm_backbone_id: str = "llama2-7b-pure"
486
- arch_specifier: str = "no-align+fused-gelu-mlp"
487
-
488
-
489
- # =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
490
- @dataclass
491
- class Prism_7B_DINOSigLIP_224px(Exp_7B_One_Stage):
492
- model_id: str = "prism-dinosiglip-224px+7b"
493
- vision_backbone_id: str = "dinosiglip-vit-so-224px"
494
- image_resize_strategy: str = "resize-naive"
495
- llm_backbone_id: str = "llama2-7b-pure"
496
- arch_specifier: str = "no-align+fused-gelu-mlp"
497
- finetune_epochs: int = 2
498
-
499
-
500
- # === Define a Model Registry Enum for Reference & Validation ===
501
- @unique
502
- class ModelRegistry(Enum):
503
- # === LLaVa v1.5 Base Reproductions ===
504
- REPRODUCTION_7B = LLaVa_v15_Reproduction_7B
505
- REPRODUCTION_13B = LLaVa_v15_Reproduction_13B
506
-
507
- # === Section 4.1 :: Optimization Procedure ===
508
- EXP_ONE_STAGE_7B = Exp_7B_One_Stage
509
- EXP_ONE_STAGE_13B = Exp_13B_One_Stage
510
-
511
- EXP_FULL_FT_MULTI_STAGE = Exp_7B_Full_Finetune_Multi_Stage
512
- EXP_FULL_FT_ONE_STAGE = Exp_7B_Full_Finetune_One_Stage
513
-
514
- # === Section 4.2 :: Image Processing and Visual Representations ===
515
- EXP_IN1K_224PX = Exp_7B_IN1K_ViT_L_p16_224px
516
- EXP_DINOV2_224PX = Exp_7B_DINOv2_ViT_L_p14_224px
517
- EXP_CLIP_224PX = Exp_7B_CLIP_ViT_L_p14_224px
518
- EXP_SIGLIP_224PX = Exp_7B_SigLIP_ViT_SO_p14_224px
519
-
520
- EXP_CLIP_336PX_RESIZE_CROP = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop
521
- EXP_CLIP_336PX_RESIZE_NAIVE = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive
522
- EXP_SIGLIP_384PX_LETTERBOX = Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox
523
- EXP_SIGLIP_384PX_RESIZE_CROP = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop
524
- EXP_SIGLIP_384PX_RESIZE_NAIVE = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive
525
-
526
- EXP_DINOCLIP_336PX_LETTERBOX = Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox
527
- EXP_DINOCLIP_336PX_RESIZE_NAIVE = Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive
528
- EXP_DINOSIGLIP_384PX_LETTERBOX = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox
529
- EXP_DINOSIGLIP_384PX_RESIZE_NAIVE = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive
530
-
531
- # === Section 4.3 :: Language Models ===
532
- EXP_LLAMA2_7B = Exp_7B_Llama2
533
- EXP_LLAMA2_13B = Exp_13B_Llama2
534
-
535
- # ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
536
- EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat
537
- EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
538
- EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
539
- EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
540
- EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2
541
-
542
- # Cotraining w/ Unimodal Data
543
- EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
544
- EXP_LLAMA2_NO_COTRAINING_7B = Exp_7B_Llama2_No_Cotraining
545
-
546
- # === Section 4.4 :: Scaling Properties - Train Time & Data ===
547
- EXP_1P25_EPOCHS = Exp_7B_1p25_Epochs
548
- EXP_1P5_EPOCHS = Exp_7B_1p5_Epochs
549
- EXP_2_EPOCHS = Exp_7B_2_Epochs
550
- EXP_3_EPOCHS = Exp_7B_3_Epochs
551
-
552
- EXP_LLAVA_LVIS4V = Exp_7B_LLaVa_LVIS4V
553
- EXP_LLAVA_LRV = Exp_7B_LLaVa_LRV
554
- EXP_LLAVA_LVIS4V_LRV = Exp_7B_LLaVa_LVIS4V_LRV
555
-
556
- # === Section 5 :: Prisms ===
557
- PRISM_CLIP_CONTROLLED_7B = Prism_7B_CLIP_Controlled
558
- PRISM_CLIP_CONTROLLED_13B = Prism_13B_CLIP_Controlled
559
- PRISM_CLIP_7B = Prism_7B_CLIP
560
- PRISM_CLIP_13B = Prism_13B_CLIP
561
-
562
- PRISM_SIGLIP_CONTROLLED_7B = Prism_7B_SigLIP_Controlled
563
- PRISM_SIGLIP_CONTROLLED_13B = Prism_13B_SigLIP_Controlled
564
- PRISM_SIGLIP_7B = Prism_7B_SigLIP
565
- PRISM_SIGLIP_13B = Prism_13B_SigLIP
566
-
567
- PRISM_DINOSIGLIP_CONTROLLED_7B = Prism_7B_DINOSigLIP_Controlled
568
- PRISM_DINOSIGLIP_CONTROLLED_13B = Prism_13B_DINOSigLIP_Controlled
569
- PRISM_DINOSIGLIP_7B = Prism_7B_DINOSigLIP
570
- PRISM_DINOSIGLIP_13B = Prism_13B_DINOSigLIP
571
-
572
- # === Inference Optimized :: 224px Prisms ===
573
- OPT_DINOSIGLIP_224PX_RESIZE_NAIVE = Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive
574
- PRISM_DINOSIGLIP_224PX_CONTROLLED_7B = Prism_7B_DINOSigLIP_224px_Controlled
575
- PRISM_DINOSIGLIP_224PX_7B = Prism_7B_DINOSigLIP_224px
576
-
577
- @property
578
- def model_id(self) -> str:
579
- return self.value.model_id
580
-
581
-
582
- # Register Models in Choice Registry
583
- for model_variant in ModelRegistry:
584
- ModelConfig.register_subclass(model_variant.model_id, model_variant.value)
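
The variants above lean on dataclass inheritance: each experiment subclass overrides only the fields that differ from its parent, and the ChoiceRegistry loop makes every variant resolvable by its string model_id. A minimal sketch (only names defined in this file):

# Exp_7B_One_Stage inherits all defaults from LLaVa_v15_Reproduction_7B
# and overrides just model_id and arch_specifier.
base = LLaVa_v15_Reproduction_7B()
one_stage = Exp_7B_One_Stage()

assert base.llm_backbone_id == one_stage.llm_backbone_id == "vicuna-v15-7b"
assert base.arch_specifier == "gelu-mlp"
assert one_stage.arch_specifier == "no-align+gelu-mlp"

# Registry members expose their ID without instantiation:
print(ModelRegistry.PRISM_DINOSIGLIP_7B.model_id)  # "prism-dinosiglip+7b"
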
capvector-oft/prismatic/conf/vla.py DELETED
@@ -1,235 +0,0 @@
- """
- vla.py
-
- Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and
- model configuration thereof. A given VLA model (`policy`) configures the following attributes:
-     - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.)
-     - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`)
-     - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
-     - Training / Optimization Hyperparameters
- """
-
- from dataclasses import dataclass
- from enum import Enum, unique
- from pathlib import Path
- from typing import Optional, Union
-
- from draccus import ChoiceRegistry
-
-
- @dataclass
- class VLAConfig(ChoiceRegistry):
-     # fmt: off
-     vla_id: str                                     # Unique VLA Policy ID that fully specifies a configuration variant
-     base_vlm: Union[str, Path]                      # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`)
-     freeze_vision_backbone: bool                    # Freeze Vision Backbone Parameters (akin to pretraining)
-     freeze_llm_backbone: bool                       # Freeze LLM Backbone parameters
-     unfreeze_last_llm_layer: bool                   # Unfreeze final layer of LLM (only takes effect if LLM is frozen)
-
-     # Data Mixture Parameters
-     data_mix: str                                   # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`)
-     shuffle_buffer_size: int                        # Size of Shuffle Buffer (100K for Bridge, 1M for OXE)
-
-     # Optimization Parameters
-     epochs: int                                     # Epochs to Run (in case `max_steps` is not specified)
-     max_steps: Optional[int]                        # [Optional] Max Gradient Steps to Run (overrides `epochs`)
-
-     expected_world_size: int                        # Expected # of GPUs =>> allows us to gate training on hardware
-     global_batch_size: int                          # Global Batch Size (divided across processes / world size)
-     per_device_batch_size: int                      # Per-Device Batch Size (per-process / individual GPU)
-                                                     #   =>> # of accumulation steps is auto-computed
-
-     learning_rate: float                            # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
-     weight_decay: float                             # Weight Decay for AdamW Optimizer
-     max_grad_norm: float                            # Max Grad Norm (for global gradient clipping)
-     lr_scheduler_type: str                          # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
-     warmup_ratio: float                             # Fraction of Steps to Warmup (for warmup LR schedulers)
-
-     train_strategy: str                             # Train Strategy (default "fsdp-full-shard")
-
-     # Enable Gradient/Activation Checkpointing (for the LLM Backbone)
-     enable_gradient_checkpointing: bool = True      # Enable Gradient/Activation Checkpointing during Training
-
-     # Mixed Precision Training via Torch Native AMP (`autocast`)
-     enable_mixed_precision_training: bool = True    # Enable Traditional BF16 Mixed Precision
-     reduce_in_full_precision: bool = True           # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision
-
-     # fmt: on
-
-
- # === OpenVLA Training Configurations ===
-
-
- # = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
- @dataclass
- class Exp_SigLIP_224px_Bridge(VLAConfig):
-     vla_id: str = "siglip-224px+mx-bridge"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-
-     freeze_vision_backbone: bool = False
-     freeze_llm_backbone: bool = False
-     unfreeze_last_llm_layer: bool = False
-
-     # Data Mixture Parameters
-     data_mix: str = "bridge"
-     shuffle_buffer_size: int = 256_000
-
-     # Optimization Parameters
-     epochs: int = 1000
-     max_steps: Optional[int] = None
-
-     expected_world_size: int = 8
-     global_batch_size: int = 256
-     per_device_batch_size: int = 32
-
-     learning_rate: float = 2e-5
-     weight_decay: float = 0.0
-     max_grad_norm: float = 1.0
-     lr_scheduler_type: str = "constant"
-     warmup_ratio: float = 0.0
-
-     train_strategy: str = "fsdp-full-shard"
-
-
- # = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
- @dataclass
- class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px-icy+mx-bridge"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-     freeze_vision_backbone: bool = True
-
-
- # = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
- @dataclass
- class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "prism-dinosiglip-224px+mx-bridge"
-     base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
-
-     data_mix: str = "bridge"
-
-
- # = [64 GPU] SigLIP 224px + OXE Magic Soup =
- @dataclass
- class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px+mx-oxe-magic-soup"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-
-     data_mix: str = "oxe_magic_soup"
-
-     expected_world_size: int = 64
-     global_batch_size: int = 2048
-     per_device_batch_size: int = 32
-
-
- # = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
- @dataclass
- class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
-     base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
-
-     # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling!
-     # data_mix: str = "oxe_magic_soup_plus"
-     data_mix: str = "oxe_magic_soup_plus_minus"
-
-     expected_world_size: int = 64
-     global_batch_size: int = 2048
-     per_device_batch_size: int = 32
-
-
- # === OpenVLA Fine-tuning Configurations ===
-
-
- # = [8 GPU] SigLIP 224px + T-DROID =
- @dataclass
- class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-
-     data_mix: str = "tdroid_carrot_in_bowl"
-
-
- @dataclass
- class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-
-     data_mix: str = "tdroid_pour_corn_in_pot"
-
-
- # = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
- @dataclass
- class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-     freeze_vision_backbone: bool = True
-     freeze_llm_backbone: bool = False
-
-     data_mix: str = "tdroid_carrot_in_bowl"
-
-
- @dataclass
- class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-     freeze_vision_backbone: bool = True
-     freeze_llm_backbone: bool = True
-     unfreeze_last_llm_layer: bool = True
-
-     data_mix: str = "tdroid_carrot_in_bowl"
-
-
- @dataclass
- class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-     freeze_vision_backbone: bool = False
-     freeze_llm_backbone: bool = True
-     unfreeze_last_llm_layer: bool = True
-
-     data_mix: str = "tdroid_carrot_in_bowl"
-
-
- # === [8 GPU] SigLIP 224px + FrankaWipe ===
- @dataclass
- class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
-     vla_id: str = "siglip-224px+mx-droid_wipe"
-     base_vlm: Union[str, Path] = "siglip-224px+7b"
-
-     data_mix: str = "droid_wipe"
-
-
- # === Define a VLA Registry Enum for Reference & Validation ===
- @unique
- class VLARegistry(Enum):
-     # Sanity Check Configurations =>> BridgeV2
-     SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
-     DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge
-
-     # SigLIP Frozen Backbone Experiment
-     FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge
-
-     # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
-     SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup
-
-     # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
-     DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus
-
-     # === TDROID Fine-tuning Configs ===
-     SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
-     SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot
-
-     SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
-     SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
-     SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl
-
-     # === DROID Fine-tuning Configs ===
-     SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe
-
-     @property
-     def vla_id(self) -> str:
-         return self.value.vla_id
-
-
- # Register VLAs in Choice Registry
- for vla_variant in VLARegistry:
-     VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value)
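
For orientation, each experiment class above only overrides fields of `VLAConfig`; everything else inherits from `Exp_SigLIP_224px_Bridge`. A minimal sketch of reading a registered config and the gradient-accumulation arithmetic implied by its batch-size fields (the actual computation lives in the training strategy; this only illustrates the relationship stated in the `per_device_batch_size` comment):

    from prismatic.conf.vla import VLARegistry

    cfg = VLARegistry.SIGLIP_224PX_MX_OXE_MAGIC_SOUP.value()
    assert cfg.vla_id == "siglip-224px+mx-oxe-magic-soup"

    # The global batch is split across `expected_world_size` GPUs; whatever does
    # not fit per device is made up with gradient accumulation steps.
    grad_accumulation_steps = cfg.global_batch_size // (cfg.per_device_batch_size * cfg.expected_world_size)
    print(grad_accumulation_steps)  # 2048 // (32 * 64) = 1
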
capvector-oft/prismatic/extern/__init__.py DELETED
File without changes
capvector-oft/prismatic/extern/hf/__init__.py DELETED
File without changes
capvector-oft/prismatic/extern/hf/configuration_prismatic.py DELETED
@@ -1,140 +0,0 @@
- """
- configuration_prismatic.py
-
- HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
- Default configuration specifies `siglip-224px+7b`.
- """
-
- from typing import Any, Dict, List, Optional
-
- from transformers import PretrainedConfig
- from transformers.models.auto import CONFIG_MAPPING
-
- # === Utilities for Mapping Prismatic names to HF names ===
- # fmt: off
- VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
-     "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
-
-     "clip-vit-l-336px": [336],
-     "siglip-vit-so400m-384px": [384],
-
-     "dinoclip-vit-l-336px": [336, 336],
-     "dinosiglip-vit-so-224px": [224, 224],
-     "dinosiglip-vit-so-384px": [384, 384],
- }
- VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
-     "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
-     "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
-
-     "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
-     "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
-
-     "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
-     "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
-
-     "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
-     "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
-     "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
- }
- TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
-     "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
-     "dinov2-vit-l": [None], "in1k-vit-l": [None],
-     "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
-     "dinoclip-vit-l-336px": [None, "quick_gelu"],
-     "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
- }
-
- LLM_BACKBONE_TO_HF_PATH = {
-     "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
-     "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
-
-     "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
-
-     "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
-     "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
-
-     "phi-2-3b": "microsoft/phi-2",
- }
- LLM_BACKBONE_TO_HF_METACLASS = {
-     "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
-     "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
-
-     "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
-
-     "phi-2-3b": "phi",
- }
-
- VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
- VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
- # fmt: on
-
-
- class PrismaticConfig(PretrainedConfig):
-     model_type: str = "prismatic"
-     is_composition: bool = False
-
-     def __init__(
-         self,
-         vision_backbone_id: str = "siglip-vit-so400m",
-         llm_backbone_id: str = "vicuna-v15-7b",
-         arch_specifier: str = "no-align+gelu-mlp",
-         use_fused_vision_backbone: Optional[bool] = None,
-         image_resize_strategy: str = "letterbox",
-         text_config: Optional[Dict[str, Any]] = None,
-         llm_max_length: int = 2048,
-         pad_token_id: int = 32000,
-         pad_to_multiple_of: int = 64,
-         output_projector_states: bool = False,
-         **kwargs: str,
-     ) -> None:
-         if vision_backbone_id not in VALID_VISION_BACKBONES:
-             raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
-
-         if llm_backbone_id not in VALID_LLM_BACKBONES:
-             raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
-
-         # Set Prismatic Configuration Fields
-         self.vision_backbone_id = vision_backbone_id
-         self.llm_backbone_id = llm_backbone_id
-         self.arch_specifier = arch_specifier
-         self.output_projector_states = output_projector_states
-
-         # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
-         self.use_fused_vision_backbone = (
-             use_fused_vision_backbone
-             if use_fused_vision_backbone is not None
-             else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
-         )
-
-         self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
-         self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
-         self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
-         self.image_resize_strategy = image_resize_strategy
-
-         self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
-         self.llm_max_length = llm_max_length
-         self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
-
-         # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
-         self.text_config = (
-             CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
-             if text_config is not None
-             else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
-         )
-
-         # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
-         super().__init__(pad_token_id=pad_token_id, **kwargs)
-
-
- class OpenVLAConfig(PrismaticConfig):
-     model_type: str = "openvla"
-
-     def __init__(
-         self,
-         norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
-         n_action_bins: int = 256,
-         **kwargs: str,
-     ) -> None:
-         self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
-
-         super().__init__(**kwargs)
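
As a quick illustration of the derived fields above: `use_fused_vision_backbone` is inferred from the backbone-ID prefix when not passed explicitly, and the TIMM IDs and resolutions are looked up from the mapping tables. A minimal sketch (assumes `transformers` is installed and the module is importable at this path):

    from prismatic.extern.hf.configuration_prismatic import PrismaticConfig

    cfg = PrismaticConfig(vision_backbone_id="dinosiglip-vit-so-224px", llm_backbone_id="llama2-7b-pure")
    assert cfg.use_fused_vision_backbone   # inferred from the "dinosiglip" prefix
    print(cfg.timm_model_ids)              # ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"]
    print(cfg.image_sizes)                 # [224, 224]
    print(cfg.hf_llm_id)                   # "meta-llama/Llama-2-7b-hf"
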
capvector-oft/prismatic/extern/hf/modeling_prismatic.py DELETED
@@ -1,1085 +0,0 @@
- """
- modeling_prismatic.py
-
- Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
- Inherits from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained,
- while exactly replicating the logic in `prismatic.models.vlms.prismatic.py`.
- """
-
- import logging
- from dataclasses import dataclass
- from functools import partial
- from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
-
- import numpy as np
- import timm
- import tokenizers
- import torch
- import torch.nn as nn
- import transformers
- from timm.models.vision_transformer import LayerScale
- from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
- from transformers.modeling_outputs import ModelOutput
-
- from prismatic.training.train_utils import (
-     get_current_action_mask,
-     get_next_actions_mask,
- )
- from prismatic.vla.constants import (
-     ACTION_DIM,
-     ACTION_PROPRIO_NORMALIZATION_TYPE,
-     ACTION_TOKEN_BEGIN_IDX,
-     IGNORE_INDEX,
-     NUM_ACTIONS_CHUNK,
-     STOP_INDEX,
-     NormalizationType,
- )
-
- from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
-
- # Set up logger
- logger = logging.getLogger(__name__)
-
-
- # === Utility Functions for Monkey-Patching ===
- def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
-     def wrapper(*args: Any, **kwargs: Any) -> Any:
-         result = fn(*args, **kwargs)
-         return result[0] if isinstance(result, tuple) else result
-
-     return wrapper
-
-
- # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
- #   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
- #   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
- def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
-     return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
-
-
- def ls_apply_patch(ls_module: LayerScale):
-     ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
-     ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
-     del ls_module.gamma
-
-
- # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
- class PrismaticVisionBackbone(nn.Module):
-     """
-     Vision backbone for Prismatic models that handles image feature extraction.
-
-     Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
-     For fused backbones, features from both models are concatenated along the feature dimension.
-     """
-
-     def __init__(
-         self,
-         use_fused_vision_backbone: bool,
-         image_sizes: List[int],
-         timm_model_ids: List[str],
-         timm_override_act_layers: List[Optional[str]],
-     ) -> None:
-         """
-         Initialize the vision backbone.
-
-         Args:
-             use_fused_vision_backbone: Whether to use two backbones and fuse their features
-             image_sizes: List of image sizes for each backbone
-             timm_model_ids: List of TIMM model IDs to use for each backbone
-             timm_override_act_layers: List of activation layer overrides for each backbone
-         """
-         super().__init__()
-         self.use_fused_vision_backbone = use_fused_vision_backbone
-         self.num_images_in_input = 1  # Default value, can be overridden later
-
-         # Validate number of (fused) vision backbones
-         if len(timm_model_ids) > 2:
-             raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
-
-         # Create primary featurizer
-         self.featurizer = self._create_featurizer(
-             model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
-         )
-         self.embed_dim = self.featurizer.embed_dim
-
-         # Create secondary featurizer if using fused backbone
-         if self.use_fused_vision_backbone:
-             self.fused_featurizer = self._create_featurizer(
-                 model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
-             )
-             self.embed_dim += self.fused_featurizer.embed_dim
-
-         # Patch LayerScale modules for HF compatibility
-         self._patch_layer_scales()
-
-     def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
-         """
-         Create a TIMM-based featurizer model with appropriate configurations.
-
-         Args:
-             model_id: The TIMM model ID to load
-             img_size: Input image size for the model
-             act_layer: Override for the activation layer type
-
-         Returns:
-             A configured featurizer model
-         """
-         featurizer = timm.create_model(
-             model_id,
-             pretrained=False,
-             num_classes=0,
-             img_size=img_size,
-             act_layer=act_layer,
-         )
-
-         # Monkey-patch the forward function to extract the second-to-last layer features
-         num_blocks = len(featurizer.blocks)
-         featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
-
-         return featurizer
-
-     def _patch_layer_scales(self) -> None:
-         """
-         Patch all LayerScale modules to be compatible with HF's parameter naming.
-
-         HF Transformers overwrites parameters with names containing 'gamma',
-         so we need to rename and modify the forward method.
-         """
-         # Patch primary featurizer
-         for module in self.featurizer.modules():
-             if isinstance(module, LayerScale):
-                 ls_apply_patch(module)
-
-         # Patch secondary featurizer if it exists
-         if self.use_fused_vision_backbone:
-             for module in self.fused_featurizer.modules():
-                 if isinstance(module, LayerScale):
-                     ls_apply_patch(module)
-
-     def get_num_patches(self) -> int:
-         """
-         Returns the number of vision patches output by the vision backbone.
-
-         Returns:
-             Number of patches per image
-         """
-         return self.featurizer.patch_embed.num_patches
-
-     def get_num_images_in_input(self) -> int:
-         """
-         Returns the number of input images for the vision backbone.
-
-         Returns:
-             Number of images expected in the input
-         """
-         return self.num_images_in_input
-
-     def set_num_images_in_input(self, num_images_in_input: int) -> None:
-         """
-         Sets the number of input images for the vision backbone.
-
-         Args:
-             num_images_in_input: Number of images to expect in the input
-         """
-         self.num_images_in_input = num_images_in_input
-
-     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-         """
-         Implements the forward pass for the vision backbone.
-
-         If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
-         (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
-
-         Args:
-             pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
-         """
-         if self.num_images_in_input == 1:
-             if not self.use_fused_vision_backbone:
-                 return self.featurizer(pixel_values)
-
-             # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
-             img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
-             patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
-
-             return torch.cat([patches, patches_fused], dim=2)
-
-         else:
-             assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
-
-             # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
-             images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
-
-             # Process each image and collect patches
-             all_patches = []
-             for img in images:
-                 # Split each image further into two stacks of channels (each with 3 channels)
-                 img_regular, img_fused = torch.split(img, [3, 3], dim=1)
-
-                 # Get patches from both SigLIP and DINOv2 vision transformers
-                 patches = self.featurizer(img_regular)
-                 patches_fused = self.fused_featurizer(img_fused)
-
-                 # Concatenate SigLIP and DINOv2 patches along the hidden dimension
-                 combined_patches = torch.cat([patches, patches_fused], dim=2)
-                 all_patches.append(combined_patches)
-
-             # Concatenate all patches along the patch dimension
-             return torch.cat(all_patches, dim=1)
-
-
- # === Prismatic Projector (nn.Module) Definitions ===
- class PrismaticProjector(nn.Module):
-     def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
-         super().__init__()
-         self.use_fused_vision_backbone = use_fused_vision_backbone
-         self.vision_dim, self.llm_dim = vision_dim, llm_dim
-
-         # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
-         if not self.use_fused_vision_backbone:
-             self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
-             self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
-             self.act_fn1 = nn.GELU()
-         else:
-             initial_projection_dim = 4 * vision_dim
-             self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
-             self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
-             self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
-             self.act_fn1 = nn.GELU()
-             self.act_fn2 = nn.GELU()
-
-     def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
-         if not self.use_fused_vision_backbone:
-             projected_features = self.fc1(img_patches)
-             projected_features = self.act_fn1(projected_features)
-             projected_features = self.fc2(projected_features)
-         else:
-             projected_features = self.fc1(img_patches)
-             projected_features = self.act_fn1(projected_features)
-             projected_features = self.fc2(projected_features)
-             projected_features = self.act_fn2(projected_features)
-             projected_features = self.fc3(projected_features)
-
-         return projected_features
-
-
- # === Main HF Class Definitions ===
- @dataclass
- class PrismaticCausalLMOutputWithPast(ModelOutput):
-     """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""
-
-     loss: Optional[torch.FloatTensor] = None
-     logits: torch.FloatTensor = None
-     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-     attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-     # Additions for VLMs
-     projector_features: Optional[torch.FloatTensor] = None
-
-
- class PrismaticPreTrainedModel(PreTrainedModel):
-     config_class: PretrainedConfig = PrismaticConfig
-     base_model_prefix: str = "model"
-     supports_gradient_checkpointing: bool = True
-
-     _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
-     _skip_keys_device_placement: str = "past_key_values"
-     _supports_flash_attn_2: bool = True
-
-     def _init_weights(self, module: nn.Module) -> None:
-         # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
-         #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
-         #      https://github.com/TRI-ML/prismatic-vlms
-         std = (
-             self.config.initializer_range
-             if hasattr(self.config, "initializer_range")
-             else self.config.text_config.initializer_range
-         )
-
-         if hasattr(module, "class_embedding"):
-             module.class_embedding.data.normal_(mean=0.0, std=std)
-
-         if isinstance(module, (nn.Linear, nn.Conv2d)):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-
-     @property
-     def _supports_sdpa(self) -> bool:
-         """Check LLM supports SDPA Attention"""
-         return self.language_model._supports_sdpa
-
-
- class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
-     def __init__(self, config: PrismaticConfig) -> None:
-         super().__init__(config)
-
-         # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
-         if config.use_fused_vision_backbone is None:
-             raise ValueError("Missing config field `use_fused_vision_backbone`")
-
-         if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
-             raise NotImplementedError(
-                 "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
-                 "if you urgently need support for latest TIMM versions."
-             )
-
-         if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
-             logger.warning(
-                 f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
-                 f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
-                 f"there might be inference-time regressions due to dependency changes. If in doubt, please "
-                 f"use the above versions."
-             )
-
-         # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
-         self.vision_backbone = PrismaticVisionBackbone(
-             config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
-         )
-
-         # Create Multimodal Projector
-         self.projector = PrismaticProjector(
-             config.use_fused_vision_backbone,
-             vision_dim=self.vision_backbone.embed_dim,
-             llm_dim=config.text_config.hidden_size,
-         )
-
-         # Instantiate LLM Backbone
-         self.language_model = AutoModelForCausalLM.from_config(
-             config.text_config, attn_implementation=config._attn_implementation
-         )
-         self.vocab_size = config.text_config.vocab_size
-         self.pad_token_id = config.pad_token_id
-         self.llm_dim = config.text_config.hidden_size
-
-         # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
-         self.post_init()
-
-     # === `PreTrainedModel` Boilerplate ===
-     def get_input_embeddings(self) -> nn.Module:
-         return self.language_model.get_input_embeddings()
-
-     def set_input_embeddings(self, value: nn.Module) -> None:
-         self.language_model.set_input_embeddings(value)
-
-     def get_output_embeddings(self) -> nn.Module:
-         return self.language_model.get_output_embeddings()
-
-     def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
-         self.language_model.set_output_embeddings(new_embeddings)
-
-     def get_decoder(self) -> nn.Module:
-         return self.language_model.get_decoder()
-
-     def set_decoder(self, decoder: nn.Module) -> None:
-         self.language_model.set_decoder(decoder)
-
-     def tie_weights(self) -> None:
-         self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
-
-     def resize_token_embeddings(
-         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
-     ) -> nn.Embedding:
-         updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-
-         # Update config/instance variables
-         self.config.text_config.vocab_size = updated_embeddings.num_embeddings
-         self.vocab_size = updated_embeddings.num_embeddings
-
-         return updated_embeddings
-
-     def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
-         """
-         Replace embeddings in input_embeddings at positions where all_actions_mask is True
-         with embeddings from noisy_action_features, using vectorized operations.
-
-         Args:
-             input_embeddings: Tensor of shape (B, S, D)
-             all_actions_mask: Boolean tensor of shape (B, S)
-             noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
-
-         Returns:
-             Modified input_embeddings tensor
-         """
-         # Clone input to avoid modifying the original tensor
-         new_input_embeddings = input_embeddings.clone()
-
-         # Create a tensor with the same shape of input_embeddings to hold the noisy action features
-         repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
-
-         # Create batch indices for splicing
-         batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
-         batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
-
-         # Get indices where mask is True for each sample
-         masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
-
-         # Move the noisy action features into their correct positions
-         repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
-
-         # Combine original input embeddings and noisy action embeddings using the mask
-         new_input_embeddings = torch.where(
-             all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
-         )
-
-         return new_input_embeddings
-
-     def _process_action_masks(self, labels):
-         """Helper to get action masks from labels"""
-         current_action_mask = get_current_action_mask(labels)
-         next_actions_mask = get_next_actions_mask(labels)
-         all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)
-         return all_actions_mask
-
-     def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
-         """Process vision features with optional FiLM conditioning"""
-         if use_film:
-             # FiLM: Infuse language inputs into visual features
-             patch_features = self.vision_backbone(pixel_values, language_embeddings)  # (bsz, 256 * num_images, D)
-         else:
-             patch_features = self.vision_backbone(pixel_values)  # (bsz, 256 * num_images, D)
-
-         # Project patch embeddings into language embedding space
-         return self.projector(patch_features)
-
-     def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
-         """Process proprioceptive features and append to vision features"""
-         if proprio_projector is not None and proprio is not None:
-             # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
-             # proprio: (bsz, proprio_dim) or (proprio_dim,)
-             proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1)  # (bsz, proprio_dim)
-             proprio_features = proprio_projector(proprio)  # (bsz, llm_dim)
-             proprio_features = proprio_features.unsqueeze(dim=1)  # (bsz, 1, llm_dim)
-             # For simplicity, just append proprio token to the end of projected vision patch tokens
-             return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
-         return projected_patch_embeddings
-
-     def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
-         """Build multimodal embeddings and attention mask"""
-         # Update attention mask
-         projected_patch_attention_mask = None
-         if attention_mask is not None:
-             projected_patch_attention_mask = torch.full(
-                 (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
-                 fill_value=True,
-                 dtype=attention_mask.dtype,
-                 device=attention_mask.device,
-             )
-
-         # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
-         multimodal_embeddings = torch.cat(
-             [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
-         )
-
-         multimodal_attention_mask = None
-         if attention_mask is not None:
-             multimodal_attention_mask = torch.cat(
-                 [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
-             )
-
-         return multimodal_embeddings, multimodal_attention_mask
-
-     def _build_multimodal_labels(self, labels, projected_patch_embeddings):
-         """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
-         if labels is not None:
-             projected_patch_labels = torch.full(
-                 (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
-                 fill_value=IGNORE_INDEX,
-                 dtype=labels.dtype,
-                 device=labels.device,
-             )
-             return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
-         return None
-
-     # === Core Prismatic VLM `forward()` Logic ===
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         pixel_values: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         output_projector_features: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         proprio=None,
-         proprio_projector=None,
-         noisy_actions=None,
-         noisy_action_projector=None,
-         diffusion_timestep_embeddings=None,
-         use_film: bool = False,
-     ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
-         """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         output_projector_features = output_projector_features if output_projector_features is not None else False
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
-         use_cache = use_cache and not self.training
-
-         # Instantiate Placeholder for Projector Features
-         projected_patch_embeddings = None
-
-         # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
-         if input_ids.shape[1] == 1:
-             assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
-             assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
-             assert labels is None, "Unexpected key `labels` provided during cached generation!"
-
-             language_model_output = self.language_model(
-                 input_ids=input_ids,
-                 attention_mask=None,
-                 position_ids=None,
-                 past_key_values=past_key_values,
-                 inputs_embeds=None,
-                 labels=None,
-                 use_cache=use_cache,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-
-         # === Handle Unimodal Forward ===
-         elif pixel_values is None:
-             assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
-             assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
-
-             language_model_output = self.language_model(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 position_ids=None,
-                 past_key_values=None,
-                 inputs_embeds=None,
-                 labels=labels,
-                 use_cache=use_cache,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-
-         # === Handle Multimodal Forward ===
-         elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
-             assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
-
-             # Get input embeddings (from language model embeddings)
-             input_embeddings = self.get_input_embeddings()(input_ids)  # (B, seq_len, D)
-
-             # Extract action masks
-             all_actions_mask = self._process_action_masks(labels)
-
-             # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
-             language_embeddings = input_embeddings[~all_actions_mask].reshape(
-                 input_embeddings.shape[0], -1, input_embeddings.shape[2]
-             )  # (B, lang_seq_len, llm_dim)
-
-             # Get visual features
-             projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
-
-             # Add proprioceptive state if provided
-             projected_patch_embeddings = self._process_proprio_features(
-                 projected_patch_embeddings, proprio, proprio_projector
-             )
-
-             # [Diffusion] Add diffusion timestep embedding if provided
-             if diffusion_timestep_embeddings is not None:
-                 # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
-                 projected_patch_embeddings = torch.cat(
-                     (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
-                 )
-
-             # Process action embeddings
-             if noisy_actions is not None:
-                 # Get mask corresponding to all action tokens
-                 all_actions_mask = self._process_action_masks(labels)
-
-                 # Reshape noisy actions into individual action tokens
-                 # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
-                 B = noisy_actions.shape[0]
-                 noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
-
-                 # Project noisy action tokens into language model embedding space
-                 noisy_action_features = noisy_action_projector(noisy_actions)  # (B, chunk_len * action_dim, llm_dim)
-
-                 # Replace embeddings of the action tokens with noisy action embeddings
-                 input_embeddings = self._replace_input_embeddings(
-                     input_embeddings, all_actions_mask, noisy_action_features
-                 )
-             else:
-                 # Replace the embeddings of the action tokens with zeros
-                 # (Later on, the positional embeddings will be added to them)
-                 all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
-                 input_embeddings = input_embeddings * ~all_actions_mask
-
-             # Build multimodal embeddings & attention mask
-             multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
-                 input_embeddings, projected_patch_embeddings, attention_mask
-             )
-
-             # Build labels for multimodal sequence if needed
-             multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
-
-             # Dispatch to language model
-             language_model_output = self.language_model(
-                 input_ids=None,
-                 attention_mask=multimodal_attention_mask,
-                 position_ids=None,
-                 past_key_values=None,
-                 inputs_embeds=multimodal_embeddings,
-                 labels=multimodal_labels,
-                 use_cache=use_cache,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-
-         # === Otherwise =>> Assume Invalid! ===
-         elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
-             raise ValueError("Non-homogeneous batch of (text, image) input -- forward() does not support mixed batches!")
-
-         else:
-             raise ValueError(
-                 "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
-                 f"=> `input_ids` = {input_ids is not None}\n"
-                 f"=> `attention_mask` = {attention_mask is not None}\n"
-                 f"=> `pixel_values` = {pixel_values is not None}\n"
-                 f"=> `labels` = {labels is not None}\n"
-                 f"=> `input_embeds` = {inputs_embeds is not None}\n"
-                 f"=> `past_key_values` = {past_key_values is not None}\n"
-                 f"=> `use_cache` = {use_cache}"
-             )
-
-         # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
-         if not return_dict:
-             if output_projector_features and (projected_patch_embeddings is not None):
-                 return *language_model_output, projected_patch_embeddings
-
-             return language_model_output
-
-         return PrismaticCausalLMOutputWithPast(
-             loss=language_model_output.loss,
-             logits=language_model_output.logits,
-             past_key_values=language_model_output.past_key_values,
-             hidden_states=language_model_output.hidden_states,
-             attentions=language_model_output.attentions,
-             projector_features=projected_patch_embeddings,
-         )
-
-     # === GenerationMixin Methods ===
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: Optional[torch.Tensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         pixel_values: Optional[torch.FloatTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         **kwargs: str,
-     ) -> Dict[str, torch.Tensor]:
-         """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
-         if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
-             (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
-         ):
-             raise ValueError("Generation with batch size > 1 is not currently supported!")
-
-         # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
-         if past_key_values is not None:
-             input_ids = input_ids[:, -1:]
-
-         # If `input_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"input_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         # Make sure `pixel_values` are preserved in `model_inputs`
-         model_inputs.update(
-             {
-                 "attention_mask": attention_mask,
-                 "pixel_values": pixel_values,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-             }
-         )
-
-         return model_inputs
-
-     # Defer to Language Model (all handle this differently, with different return types)
-     def _reorder_cache(self, *args, **kwargs) -> Any:
-         return self.language_model._reorder_cache(*args, **kwargs)
-
-
- class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
-     config_class: PretrainedConfig = OpenVLAConfig
-
-     def __init__(self, config: OpenVLAConfig) -> None:
-         super().__init__(config)
-         self.norm_stats = config.norm_stats
-
-         # Compute action bins
-         self.bins = np.linspace(-1, 1, config.n_action_bins)
-         self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
-
-         # Compute vocab size for de-tokenization -- revert added "multiple of"
-         self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
-
-     def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
-         """Prepares input for action prediction by adding necessary tokens"""
-         # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
-         placeholder_action_token_ids = (
-             torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
-         )
-         input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
-
-         # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
-         stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
-         input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
-
-         # Extend the attention mask to fit the new shape of input
-         # Note: Only batch size == 1 supported right now
-         mask_extension = (
-             torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
-             .to(attention_mask.device)
-             .to(attention_mask.dtype)
-         )
-         attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
-
-         return input_ids, attention_mask
-
-     def _prepare_labels_for_action_prediction(self, labels, input_ids):
-         """Creates labels tensor for action prediction if not provided"""
-         # Extend labels tensor with fake action labels
-         ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
-         labels_extension = (
-             torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
-             * ARBITRARY_ACTION_TOKEN_IDX
-         )
-         labels = torch.cat([labels, labels_extension], dim=-1)
-
-         # Replace last label token with stop token
-         labels[:, -1] = STOP_INDEX
-
-         return labels
-
-     def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
-         """Unnormalize actions using dataset statistics"""
-         action_norm_stats = self.get_action_stats(unnorm_key)
-
-         if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
-             mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
-             action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
-         elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
-             mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
-             action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
-         else:
-             raise ValueError("Unsupported action/proprio normalization type detected!")
-
-         actions = np.where(
-             mask,
-             0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
-             normalized_actions,
-         )
-
-         return actions
-
-     def _run_diffusion_prediction(
-         self,
-         input_embeddings,
-         all_actions_mask,
-         noise,
-         action_head,
-         projected_patch_embeddings,
-         labels,
-         attention_mask,
-         NUM_PATCHES,
-         NUM_PROMPT_TOKENS,
-         noisy_action_projector,
-     ):
-         """Run diffusion-based action prediction"""
-         # Clone embedding for reuse in each timestep
-         orig_projected_patch_embeddings = projected_patch_embeddings.clone()
-         curr_noisy_actions = noise
-
-         # Reverse diffusion: Iteratively denoise to generate action prediction
-         for t in action_head.noise_scheduler.timesteps:
-             # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
-             # embedding, and diffusion timestep embedding)
-             timesteps = torch.Tensor([t]).to(labels.device)
-             diffusion_timestep_embeddings = (
-                 action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
-             )  # (B, llm_dim)
-             diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)
-
-             # [Diffusion] Replace the embeddings of the action tokens with noisy actions
-             # (Later on, the positional embeddings will be added to them)
-
-             # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
-             projected_patch_embeddings = torch.cat(
-                 (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
-             )
-
-             # Reshape and project noisy actions into language embedding space
-             B = curr_noisy_actions.shape[0]
-             orig_curr_noisy_actions_shape = curr_noisy_actions.shape
-             curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
-             noisy_action_features = noisy_action_projector(curr_noisy_actions)
-             curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
-
-             # Replace action token embeddings with noisy action embeddings
-             input_embeddings = self._replace_input_embeddings(
-                 input_embeddings.clone(), all_actions_mask, noisy_action_features
-             )
-
-             # Build multimodal embeddings and attention mask
-             multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
-                 input_embeddings, projected_patch_embeddings, attention_mask
-             )
-
-             # Forward pass through language model
-             language_model_output = self.language_model(
-                 input_ids=None,
-                 attention_mask=multimodal_attention_mask,
-                 position_ids=None,
-                 past_key_values=None,
-                 inputs_embeds=multimodal_embeddings,
-                 labels=None,
-                 use_cache=None,
-                 output_attentions=False,
-                 output_hidden_states=True,
-                 return_dict=True,
-             )
-
-             # Extract hidden states for action portion of response
-             last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)
-             actions_hidden_states = last_hidden_states[
-                 :,
-                 NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
-                 :,
-             ]  # (B, act_chunk_len, D)
-
-             # Predict noise and update noisy actions: x_t -> x_{t-1}
-             noise_pred = action_head.predict_noise(actions_hidden_states)
-             curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
-
-         curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
-
-         # Return final actions
-         return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
-
-     def _regression_or_discrete_prediction(
-         self,
-         input_embeddings,
-         all_actions_mask,
-         projected_patch_embeddings,
-         attention_mask,
-         labels,
-         NUM_PATCHES,
-         NUM_PROMPT_TOKENS,
-         action_head=None,
-     ):
-         """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
-         # Zero out action token embeddings
-         all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
-         input_embeddings = input_embeddings * ~all_actions_mask
-
-         # Build multimodal embeddings and attention mask
-         multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
-             input_embeddings, projected_patch_embeddings, attention_mask
-         )
-
-         # Forward pass through language model
-         language_model_output = self.language_model(
-             input_ids=None,
-             attention_mask=multimodal_attention_mask,
-             position_ids=None,
-             past_key_values=None,
-             inputs_embeds=multimodal_embeddings,
-             labels=None,
-             use_cache=None,
-             output_attentions=False,
-             output_hidden_states=True,
-             return_dict=True,
-         )
-
-         # Extract hidden states for action tokens
-         last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)
-         actions_hidden_states = last_hidden_states[
-             :,
-             NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
-             :,
-         ]  # (B, act_chunk_len, D)
-
-         # Handle different prediction methods
-         if action_head is not None:
-             # L1 regression prediction
-             normalized_actions = action_head.predict_action(actions_hidden_states)
-             normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
-             normalized_actions = normalized_actions.float().cpu().detach().numpy()
-         else:
-             # Discrete token-based prediction
-             predicted_action_token_ids = (
-                 language_model_output.logits[
-                     :,
-                     NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
-                 ]
-                 .argmax(dim=2)
-                 .cpu()
-                 .numpy()
-             )
-             discretized_actions = self.vocab_size - predicted_action_token_ids
-             discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
-             normalized_actions = self.bin_centers[discretized_actions]
-             normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
-
-         return normalized_actions, actions_hidden_states
-
-     def predict_action(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         unnorm_key: Optional[str] = None,
-         proprio=None,
-         proprio_projector=None,
-         action_head=None,
-         noisy_action_projector=None,
-         use_film: bool = False,
-         **kwargs: str,
-     ) -> np.ndarray:
-         """Predict actions from input sequence, with options for different prediction methods.
-
-         Args:
-             input_ids: Input token ids
-             unnorm_key: Key for unnormalization statistics
-             proprio: Proprioceptive features
-             proprio_projector: Projector for proprioceptive features
-             action_head: Optional head for L1 regression or diffusion-based prediction
-             noisy_action_projector: Projector for noisy actions in diffusion-based prediction
-             use_film: Whether to use FiLM conditioning
-             **kwargs: Additional arguments including pixel_values and attention_mask
-
-         Returns:
-             Tuple of (unnormalized_actions, action_hidden_states)
-         """
-         # If the special empty token ('') does not already appear after the colon (':') token in the prompt
-         # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
-         if not torch.all(input_ids[:, -1] == 29871):
-             input_ids = torch.cat(
-                 (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
-             )
-
-         pixel_values = kwargs["pixel_values"]
-         attention_mask = kwargs["attention_mask"]
-
-         # Create fake labels tensor (needed for action mask)
-         labels = input_ids.clone()
-         labels[:] = IGNORE_INDEX
-
-         # Get number of tokens in prompt (excluding the start token)
-         NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1  # Subtract action tokens and stop token
-
-         # Prepare inputs by adding necessary tokens
-         input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
-
-         # Update labels tensor for action mask computation later
-         labels = self._prepare_labels_for_action_prediction(labels, input_ids)
-
-         # Get input embeddings and action masks
-         input_embeddings = self.get_input_embeddings()(input_ids)
-         all_actions_mask = self._process_action_masks(labels)
-
-         # Extract language embeddings
-         language_embeddings = input_embeddings[~all_actions_mask].reshape(
-             input_embeddings.shape[0], -1, input_embeddings.shape[2]
-         )
-
-         # Process vision features
-         projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
-
-         # Add proprioceptive features if provided
-         use_proprio = proprio_projector is not None and proprio is not None
-         if use_proprio:
-             proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
-             projected_patch_embeddings = self._process_proprio_features(
-                 projected_patch_embeddings, proprio, proprio_projector
-             )
-
-         # Use diffusion if provided, otherwise use regression or discrete prediction
-         use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
-
-         # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
-         NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
-         if use_proprio:
-             NUM_PATCHES += 1
-         if use_diffusion:
-             NUM_PATCHES += 1
-
-         if use_diffusion:
-             # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
-             noise = torch.randn(
-                 size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
-             )
-
-             # Run diffusion-based prediction
-             normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
-                 input_embeddings,
-                 all_actions_mask,
-                 noise,
-                 action_head,
-                 projected_patch_embeddings,
-                 labels,
-                 attention_mask,
-                 NUM_PATCHES,
-                 NUM_PROMPT_TOKENS,
-                 noisy_action_projector,
-             )
-         else:
-             # Run regression or discrete token-based prediction
-             normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1045
- input_embeddings,
1046
- all_actions_mask,
1047
- projected_patch_embeddings,
1048
- attention_mask,
1049
- labels,
1050
- NUM_PATCHES,
1051
- NUM_PROMPT_TOKENS,
1052
- action_head,
1053
- )
1054
-
1055
- # Unnormalize predicted actions
1056
- actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1057
-
1058
- return actions, actions_hidden_states
1059
-
1060
- @staticmethod
1061
- def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1062
- """Validate and resolve the unnormalization key for action statistics"""
1063
- if unnorm_key is None:
1064
- assert len(norm_stats) == 1, (
1065
- f"Your model was trained on more than one dataset, "
1066
- f"please pass a `unnorm_key` from the following options to choose the statistics "
1067
- f"used for un-normalizing actions: {norm_stats.keys()}"
1068
- )
1069
- unnorm_key = next(iter(norm_stats.keys()))
1070
-
1071
- assert unnorm_key in norm_stats, (
1072
- f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1073
- f"please choose from: {norm_stats.keys()}"
1074
- )
1075
- return unnorm_key
1076
-
1077
- def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1078
- """Get the dimensionality of the policy's action space."""
1079
- unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1080
- return len(self.norm_stats[unnorm_key]["action"]["min"])
1081
-
1082
- def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1083
- """Get all the logged statistics for the given dataset."""
1084
- unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1085
- return self.norm_stats[unnorm_key]["action"]
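
For orientation, here is a minimal, hedged sketch of how the `predict_action` entry point above is typically driven at inference time. The checkpoint id, image path, and `unnorm_key` below are placeholders rather than values taken from this repo, and the AutoClass loading pattern assumes the checkpoint registers these classes via `trust_remote_code`:

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

CKPT = "openvla/openvla-7b"  # placeholder checkpoint id
processor = AutoProcessor.from_pretrained(CKPT, trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(CKPT, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")

image = Image.open("obs.png")  # placeholder camera frame
prompt = "In: What action should the robot take to pick up the cup?\nOut:"
inputs = processor(prompt, image).to("cuda", dtype=torch.bfloat16)

# predict_action consumes input_ids plus pixel_values / attention_mask via **kwargs and
# returns (unnormalized action chunk, action hidden states)
actions, _ = vla.predict_action(**inputs, unnorm_key="bridge_orig")  # unnorm_key is a placeholder
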
capvector-oft/prismatic/extern/hf/processing_prismatic.py DELETED
@@ -1,252 +0,0 @@
- """
- processing_prismatic.py
-
- HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
- specifies `siglip-224px+7b`.
- """
-
- from typing import Any, ClassVar, List, Optional, Tuple, Union
-
- import timm.data
- import torch
- import torchvision.transforms.functional as TVF
- from PIL import Image
- from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
- from transformers import PreTrainedTokenizerBase
- from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
- from transformers.processing_utils import ProcessorMixin
- from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
- from transformers.utils import TensorType
-
-
- # === Image Processing ===
- def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
-     """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
-     (w, h), max_wh = image.size, max(image.size)
-     horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
-     padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
-
-     return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
-
-
- class PrismaticImageProcessor(ImageProcessingMixin):
-     model_input_names: ClassVar[List[str]] = ["pixel_values"]
-
-     def __init__(
-         self,
-         use_fused_vision_backbone: bool = False,
-         image_resize_strategy: str = "letterbox",
-         input_sizes: Optional[List[Tuple[int, int, int]]] = None,
-         interpolations: Optional[List[str]] = None,
-         means: Optional[List[Tuple[float, float, float]]] = None,
-         stds: Optional[List[Tuple[float, float, float]]] = None,
-         **kwargs: str,
-     ) -> None:
-         """
-         Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
-         created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
-         @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
-         @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
-         @param input_sizes: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
-         @param interpolations: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
-         @param means: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
-         @param stds: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
-         """
-         self.use_fused_vision_backbone = use_fused_vision_backbone
-         self.image_resize_strategy = image_resize_strategy
-
-         # Handle `None` default values
-         input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
-         means = [(0.5, 0.5, 0.5)] if means is None else means
-         stds = [(0.5, 0.5, 0.5)] if stds is None else stds
-
-         # TIMM `data_cfg` Parameters
-         self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
-
-         # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
-         self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
-         self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
-
-         for idx in range(len(input_sizes)):
-             transform = timm.data.create_transform(
-                 input_size=self.input_sizes[idx],
-                 interpolation=self.interpolations[idx],
-                 mean=self.means[idx],
-                 std=self.stds[idx],
-                 crop_pct=1.0,        # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
-                 crop_mode="center",  # Default crop mode -- no-op when `crop_pct == 1.0`
-                 is_training=False,   # No image augmentations when loading the transform!
-             )
-
-             # [Validation] Ensure appropriate transform structure, expected sizes
-             if not (
-                 isinstance(transform, Compose)
-                 and (len(transform.transforms) == 4)
-                 and isinstance(transform.transforms[0], Resize)
-                 and isinstance(transform.transforms[1], CenterCrop)
-                 and isinstance(transform.transforms[2], ToTensor)
-                 and isinstance(transform.transforms[3], Normalize)
-                 and (transform.transforms[0].size == self.input_sizes[idx][-1])
-                 and (transform.transforms[1].size == self.input_sizes[idx][-2:])
-             ):
-                 raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
-
-             # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision transforms as attributes.
-             # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
-             resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
-             self.tvf_resize_params.append(
-                 {
-                     "size": resize_t.size,
-                     "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
-                     "max_size": None,
-                     "antialias": True,
-                 }
-             )
-             self.tvf_crop_params.append({"output_size": crop_t.size})
-             self.tvf_normalize_params.append(
-                 {
-                     "mean": norm_t.mean.float().numpy().tolist(),
-                     "std": norm_t.std.float().numpy().tolist(),
-                     "inplace": False,
-                 }
-             )
-             self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
-
-             # Handle Prismatic `image_resize_strategy`
-             if self.image_resize_strategy == "resize-naive":
-                 self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
-             elif self.image_resize_strategy == "letterbox":
-                 self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
-             elif self.image_resize_strategy == "resize-crop":
-                 pass
-             else:
-                 raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
-
-         # Dispatch **kwargs to super()
-         super().__init__(**kwargs)
-
-     def apply_transform(self, img: Image.Image) -> torch.Tensor:
-         """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
-         if self.tvf_do_letterbox:
-             img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
-
-         # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
-         imgs_t = []
-         for idx in range(len(self.input_sizes)):
-             img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
-             img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
-             img_idx_t = TVF.to_tensor(img_idx)
-             img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
-             imgs_t.append(img_idx_t)
-
-         # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
-         img_t = torch.vstack(imgs_t)
-
-         return img_t
-
-     def preprocess(
-         self,
-         images: Union[Image.Image, List[Image.Image]],
-         return_tensors: Optional[Union[str, TensorType]] = None,
-         **_: str,
-     ) -> BatchFeature:
-         """
-         Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
-         explicitly only handle PIL.Image.Image instances for simplicity.
-         @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
-         @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
-         @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
-         """
-         if not isinstance(images, list):
-             images = [images]
-
-         # Apply `self.apply_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
-         pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
-
-         # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
-         return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
-
-     def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
-         return self.preprocess(images, **kwargs)
-
-
- # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
- # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
- class PrismaticProcessor(ProcessorMixin):
-     attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
-     image_processor_class: str = "AutoImageProcessor"
-     tokenizer_class: str = "AutoTokenizer"
-
-     def __init__(
-         self,
-         image_processor: Optional[ImageProcessingMixin] = None,
-         tokenizer: Optional[PreTrainedTokenizerBase] = None,
-     ) -> None:
-         super().__init__(image_processor, tokenizer)
-
-     def __call__(
-         self,
-         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
-         images: Union[Image.Image, List[Image.Image]],
-         padding: Union[bool, str, PaddingStrategy] = False,
-         truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
-         max_length: Optional[int] = None,
-         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-     ) -> BatchFeature:
-         """
-         Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
-         forwards images to PrismaticImageProcessor.
-         @param text: The (batch) of text to encode; must be a string or list of strings.
-         @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
-         @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
-         @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
-         @param max_length: Maximum length (in tokens) to truncate
-         @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
-         @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
-         """
-         pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
-         text_inputs = self.tokenizer(
-             text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
-         )
-
-         # [Validate] Need same number of images and text inputs!
-         if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
-             raise ValueError("Batch is malformed; expected same number of images and text inputs!")
-
-         return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
-
-     # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
-     def batch_decode(
-         self,
-         sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
-         skip_special_tokens: bool = False,
-         clean_up_tokenization_spaces: Optional[bool] = None,
-         **kwargs: str,
-     ) -> List[str]:
-         return self.tokenizer.batch_decode(
-             sequences=sequences,
-             skip_special_tokens=skip_special_tokens,
-             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-             **kwargs,
-         )
-
-     def decode(
-         self,
-         token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
-         skip_special_tokens: bool = False,
-         clean_up_tokenization_spaces: Optional[bool] = None,
-         **kwargs: str,
-     ) -> str:
-         return self.tokenizer.decode(
-             token_ids=token_ids,
-             skip_special_tokens=skip_special_tokens,
-             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-             **kwargs,
-         )
-
-     @property
-     def model_input_names(self) -> List[str]:
-         tokenizer_input_names = self.tokenizer.model_input_names
-         image_processor_input_names = self.image_processor.model_input_names
-
-         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
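
A hedged usage sketch for the two processor classes above. The Llama-2 hub path (gated) and the image path are placeholders; note that `interpolations` has no `None` default in `__init__`, so it is passed explicitly here to avoid indexing into `None`:

from PIL import Image
from transformers import AutoTokenizer

image_processor = PrismaticImageProcessor(image_resize_strategy="letterbox", interpolations=["bicubic"])
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # gated; placeholder
processor = PrismaticProcessor(image_processor=image_processor, tokenizer=tokenizer)

batch = processor("In: What is in this image?\nOut:", Image.open("example.png"))
print(batch["input_ids"].shape, batch["pixel_values"].shape)  # e.g., (1, L) and (1, 3, 224, 224)
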
capvector-oft/prismatic/models/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .load import available_model_names, available_models, get_model_description, load, load_vla
- from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
capvector-oft/prismatic/models/action_heads.py DELETED
@@ -1,211 +0,0 @@
- """Implementations of various action heads, which serve as alternatives to VLM sequential token prediction."""
-
- import math
-
- import numpy as np
- import torch
- import torch.nn as nn
- from diffusers.schedulers.scheduling_ddim import DDIMScheduler
- from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX
-
-
- class SinusoidalPositionalEncoding(nn.Module):
-     """
-     Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps.
-
-     For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,)
-     Then the output would be a batch of 32 timestep embeddings -> shape (32, D)
-
-     Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py
-     """
-
-     def __init__(self, dim):
-         super().__init__()
-         self.dim = dim  # dimensionality of the positional encoding
-
-     def forward(self, x):
-         # x: (batch_size,)
-         device = x.device
-         assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}"
-         half_dim = self.dim // 2
-         exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1)  # shape: (D/2,)
-         emb = torch.exp(exponent)  # shape: (D/2,)
-         emb = x[:, None] * emb[None, :]  # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2)
-         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # shape: (batch_size, D)
-         return emb
-
-
- class MLPResNetBlock(nn.Module):
-     """One MLP ResNet block with a residual connection."""
-     def __init__(self, dim):
-         super().__init__()
-         self.dim = dim
-         self.ffn = nn.Sequential(  # feedforward network, similar to the ones in Transformers
-             nn.LayerNorm(dim),
-             nn.Linear(dim, dim),
-             nn.ReLU(),
-         )
-
-     def forward(self, x):
-         # x: (batch_size, hidden_dim)
-         # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as
-         # described here: https://arxiv.org/pdf/2002.04745.pdf
-         identity = x
-         x = self.ffn(x)
-         x = x + identity
-         return x
-
-
- class MLPResNet(nn.Module):
-     """MLP with residual connection blocks."""
-     def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
-         super().__init__()
-         self.layer_norm1 = nn.LayerNorm(input_dim)
-         self.fc1 = nn.Linear(input_dim, hidden_dim)
-         self.relu = nn.ReLU()
-         self.mlp_resnet_blocks = nn.ModuleList()
-         for _ in range(num_blocks):
-             self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim))
-         self.layer_norm2 = nn.LayerNorm(hidden_dim)
-         self.fc2 = nn.Linear(hidden_dim, output_dim)
-
-     def forward(self, x):
-         # x: (batch_size, input_dim)
-         x = self.layer_norm1(x)  # shape: (batch_size, input_dim)
-         x = self.fc1(x)  # shape: (batch_size, hidden_dim)
-         x = self.relu(x)  # shape: (batch_size, hidden_dim)
-         for block in self.mlp_resnet_blocks:
-             x = block(x)  # shape: (batch_size, hidden_dim)
-         x = self.layer_norm2(x)  # shape: (batch_size, hidden_dim)
-         x = self.fc2(x)  # shape: (batch_size, output_dim)
-         return x
-
-
- class L1RegressionActionHead(nn.Module):
-     """Simple MLP-based action head that generates continuous actions via L1 regression."""
-     def __init__(
-         self,
-         input_dim=4096,
-         hidden_dim=4096,
-         action_dim=7,
-     ):
-         super().__init__()
-         self.action_dim = action_dim
-         self.model = MLPResNet(
-             num_blocks=2, input_dim=input_dim * ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim
-         )
-
-     def predict_action(self, actions_hidden_states):
-         # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence
-         # - shape: (batch_size, chunk_len * action_dim, hidden_dim)
-         batch_size = actions_hidden_states.shape[0]
-         device = actions_hidden_states.device
-         rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)
-         action = self.model(rearranged_actions_hidden_states)
-         return action
-
-
- class NoisePredictionModel(nn.Module):
-     """
-     Diffusion noise prediction model that takes an observation embedding (which fuses the
-     noisy action, diffusion timestep, and image-language observation embeddings) and
-     outputs a noise prediction.
-     """
-
-     def __init__(
-         self,
-         transformer_hidden_dim,  # Transformer hidden embedding size
-         hidden_dim,              # MLP hidden size
-         action_dim=7,            # action dimensionality
-     ):
-         super().__init__()
-         self.mlp_resnet = MLPResNet(
-             num_blocks=2,
-             input_dim=transformer_hidden_dim,
-             hidden_dim=hidden_dim,
-             output_dim=action_dim,
-         )
-
-     def forward(
-         self,
-         obs,
-     ):
-         # obs: observation embeddings to condition the generation on
-         # - shape: (batch_size, chunk_len, rearranged_hidden_dim=action_dim*hidden_dim)
-         #
-         # output: predicted noise
-         # - shape: (batch_size, chunk_len, action_dim)
-         output = self.mlp_resnet(obs)
-         return output
-
-
- class DiffusionActionHead(nn.Module):
-     """
-     Simple MLP-based action head that generates continuous actions via conditional denoising diffusion process.
-
-     Loosely inspired by: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py
-     """
-
-     def __init__(
-         self,
-         input_dim=4096,
-         hidden_dim=4096,
-         action_dim=7,
-         num_diffusion_steps_train=50,
-     ):
-         super().__init__()
-         self.action_dim = action_dim
-         self.noise_predictor = NoisePredictionModel(
-             transformer_hidden_dim=hidden_dim * ACTION_DIM, hidden_dim=hidden_dim, action_dim=action_dim
-         )
-         self.num_diffusion_steps_train = num_diffusion_steps_train
-         self.noise_scheduler = DDIMScheduler(num_train_timesteps=num_diffusion_steps_train, beta_schedule="squaredcos_cap_v2")
-         self.time_encoder = SinusoidalPositionalEncoding(dim=hidden_dim)
-
-     def sample_noisy_actions(self, ground_truth_actions):
-         """
-         Samples noise and applies noise to ground-truth actions to produce noisy actions, which are
-         used as input in the noise prediction network. Returns noise, noisy actions, and the
-         corresponding diffusion timestep embeddings.
-         """
-         # ground_truth_actions: ground-truth actions
-         # - shape: (batch_size, chunk_len, action_dim)
-         batch_size = ground_truth_actions.shape[0]
-         device = ground_truth_actions.device
-         # Sample random noise with shape equal to actions, used for closed-form forward diffusion.
-         noise = torch.randn(size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), device=device, dtype=ground_truth_actions.dtype)  # (B, chunk_len, action_dim)
-         # Sample random diffusion timesteps (one for each action in batch).
-         timesteps = torch.randint(
-             low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device
-         )
-         # Add noise to clean actions according to the magnitude at each diffusion timestep via
-         # closed-form forward diffusion.
-         noisy_actions = self.noise_scheduler.add_noise(ground_truth_actions, noise, timesteps)  # (B, chunk_len, action_dim)
-
-         # Get diffusion timestep embeddings as well
-         diffusion_timestep_embeddings = self.time_encoder(timesteps).to(noisy_actions.dtype).to(noisy_actions.device)  # (B, llm_dim)
-         diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)
-
-         return_dict = dict(
-             noise=noise,
-             noisy_actions=noisy_actions,
-             diffusion_timestep_embeddings=diffusion_timestep_embeddings,
-         )
-
-         return return_dict
-
-     def predict_noise(self, actions_hidden_states):
-         """
-         Given a batch of last hidden Transformer layer embeddings (which fuse the vision-language observation embeddings,
-         noisy action embeddings, and diffusion timestep embedding), predicts the noise applied to the actions.
-         """
-         # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence
-         # - shape: (batch_size, chunk_len * action_dim, hidden_dim)
-         batch_size = actions_hidden_states.shape[0]
-         device = actions_hidden_states.device
-         rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)  # (batch_size, chunk_len, action_dim * hidden_dim)
-         # Get diffusion model's noise prediction.
-         noise_pred = self.noise_predictor(rearranged_actions_hidden_states)
-         return noise_pred
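
As a shape check, a hedged sketch of driving `L1RegressionActionHead` with dummy hidden states. `ACTION_DIM` and `NUM_ACTIONS_CHUNK` must match the values in `prismatic.vla.constants`; the 7 and 8 used below are illustrative assumptions, not values confirmed by this repo:

import torch

B, HIDDEN = 2, 4096
ACTION_DIM_EX, NUM_ACTIONS_CHUNK_EX = 7, 8  # illustrative; the real values come from prismatic.vla.constants

head = L1RegressionActionHead(input_dim=HIDDEN, hidden_dim=HIDDEN, action_dim=ACTION_DIM_EX)

# Last-layer LLM states at the NUM_ACTIONS_CHUNK * ACTION_DIM action-token positions
hidden_states = torch.randn(B, NUM_ACTIONS_CHUNK_EX * ACTION_DIM_EX, HIDDEN)
actions = head.predict_action(hidden_states)  # internally reshaped to (B, chunk_len, action_dim * hidden_dim)
print(actions.shape)  # torch.Size([2, 8, 7]) -- one 7-DoF action per chunk step
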
capvector-oft/prismatic/models/backbones/__init__.py DELETED
File without changes
capvector-oft/prismatic/models/backbones/llm/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .base_llm import LLMBackbone
- from .llama2 import LLaMa2LLMBackbone
- from .mistral import MistralLLMBackbone
- from .phi import PhiLLMBackbone
capvector-oft/prismatic/models/backbones/llm/base_llm.py DELETED
@@ -1,223 +0,0 @@
- """
- base_llm.py
-
- Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class
- methods, utility functions, and initialization logic.
-
- We also define the generic HFCausalLLMBackbone class here, providing a default interface for loading any HF
- AutoModelForCausalLM (e.g., LlamaForCausalLM). In general, we make the assumption that any given LLM backbone implements
- the AutoModelForCausalLM API (though we may add Seq2Seq models in the future).
-
- We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF
- utilities around different types of decoding/generation strategies.
- """
-
- import warnings
- from abc import ABC, abstractmethod
- from functools import partial
- from typing import Callable, List, Optional, Sequence, Type
-
- import torch
- import torch.nn as nn
- from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
- from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
- from transformers.modeling_outputs import CausalLMOutputWithPast
-
- from prismatic.models.backbones.llm.prompting import PromptBuilder
- from prismatic.overwatch import initialize_overwatch
-
- # Suppress HF Deprecation Warnings
- warnings.filterwarnings("ignore", category=FutureWarning)
-
- # Initialize Overwatch =>> Wraps `logging.Logger`
- overwatch = initialize_overwatch(__name__)
-
-
- # === Abstract Base Class for arbitrary HF LLM Backbones ===
- class LLMBackbone(nn.Module, ABC):
-     def __init__(self, llm_backbone_id: str) -> None:
-         super().__init__()
-         self.identifier = llm_backbone_id
-
-         # Instance attributes for an LLM Backbone
-         self.llm: PreTrainedModel = None
-         self.tokenizer: PreTrainedTokenizerBase = None
-
-     def get_tokenizer(self) -> PreTrainedTokenizerBase:
-         return self.tokenizer
-
-     @abstractmethod
-     def get_fsdp_wrapping_policy(self) -> Callable: ...
-
-     @abstractmethod
-     def enable_gradient_checkpointing(self) -> None: ...
-
-     @abstractmethod
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> CausalLMOutputWithPast:
-         """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss"""
-         raise NotImplementedError
-
-     @abstractmethod
-     def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ...
-
-     @property
-     @abstractmethod
-     def prompt_builder_fn(self) -> Type[PromptBuilder]: ...
-
-     @property
-     @abstractmethod
-     def transformer_layer_cls(self) -> Type[nn.Module]: ...
-
-     @property
-     @abstractmethod
-     def half_precision_dtype(self) -> torch.dtype: ...
-
-     @property
-     @abstractmethod
-     def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ...
-
-     @property
-     def embed_dim(self) -> int:
-         return self.llm.config.hidden_size
-
-     @property
-     def pad_token_id(self) -> int:
-         return self.tokenizer.pad_token_id
-
-
- # === Abstract Base Class for Arbitrary HF Causal LLMs ===
- class HFCausalLLMBackbone(LLMBackbone, ABC):
-     def __init__(
-         self,
-         llm_backbone_id: str,
-         llm_family: str,
-         llm_cls: Type[PreTrainedModel],
-         hf_hub_path: str,
-         llm_max_length: int = 2048,
-         hf_token: Optional[str] = None,
-         inference_mode: bool = False,
-         use_flash_attention_2: bool = False,
-     ) -> None:
-         super().__init__(llm_backbone_id)
-         self.llm_family = llm_family
-         self.llm_max_length = llm_max_length
-         self.inference_mode = inference_mode
-
-         # Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class!
-         # => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details
-         if not self.inference_mode:
-             overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
-             self.llm = llm_cls.from_pretrained(
-                 hf_hub_path,
-                 token=hf_token,
-                 use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False,
-                 # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
-                 do_sample=False,
-                 temperature=1.0,
-                 top_p=1.0,
-             )
-
-         # [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights!
-         else:
-             overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
-             llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token)
-             self.llm = llm_cls._from_config(llm_config)
-
-         # Lightweight Handling (with extended explanation) for setting some LLM Parameters
-         # => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general)
-         #
-         #    Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958
-         self.llm.config.use_cache = False if not self.inference_mode else True
-
-         # => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters
-         #    (requires_grad is False), backprop will fail; setting `enable_input_require_grads()` registers a new
-         #    forward hook that fixes this =>> also totally safe for the "full finetuning" setting!
-         if not self.inference_mode:
-             self.llm.enable_input_require_grads()
-
-         # Load (Fast) Tokenizer
-         overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right"
-         )
-
-         # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
-         #                starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
-         #                find that adding image patches *after* the BOS leads to much better performance.
-         #
-         # As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
-         # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
-         # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
-         # and VLM `forward()` logic!
-         SPECIAL_CASES = {
-             # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
-             # =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that
-             #     this works well with base LLM generation.
-             # =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
-             "phi-2-3b",
-         }
-         if self.identifier in SPECIAL_CASES:
-             return
-
-         # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral!)
-         assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
-             self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
-         ), (
-             f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n"
-             "Please read the comment in `base_llm.py` for more information!"
-         )
-
-     def get_fsdp_wrapping_policy(self) -> Callable:
-         """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
-         transformer_block_policy = partial(
-             transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
-         )
-
-         return transformer_block_policy
-
-     def enable_gradient_checkpointing(self) -> None:
-         """Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PretrainedModel`."""
-         self.llm.gradient_checkpointing_enable()
-
-     def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
-         return self.llm.get_input_embeddings()(input_ids)
-
-     # [Contract] Should match the `forward` call of the underlying `llm` instance!
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> CausalLMOutputWithPast:
-         output: CausalLMOutputWithPast = self.llm(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             labels=labels,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         return output
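
The policy returned by `get_fsdp_wrapping_policy` plugs directly into PyTorch FSDP. A hedged sketch, assuming `torch.distributed` is already initialized and `backbone` is an instance of a concrete subclass (neither is set up here):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

policy = backbone.get_fsdp_wrapping_policy()  # wraps every `transformer_layer_cls` block
sharded_llm = FSDP(backbone.llm, auto_wrap_policy=policy)
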
capvector-oft/prismatic/models/backbones/llm/llama2.py DELETED
@@ -1,102 +0,0 @@
- """
- llama2.py
-
- Class definition for all LLMs derived from LlamaForCausalLM.
- """
-
- from typing import Optional, Sequence, Type
-
- import torch
- from torch import nn as nn
- from transformers import LlamaForCausalLM
- from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
- from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
- from prismatic.models.backbones.llm.prompting import (
-     LLaMa2ChatPromptBuilder,
-     PromptBuilder,
-     PurePromptBuilder,
-     VicunaV15ChatPromptBuilder,
- )
-
- # Registry =>> Support LLaMa-2 Models (from HF Transformers)
- # fmt: off
- LLAMA2_MODELS = {
-     # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models ===
-     "llama2-7b-pure": {
-         "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf"
-     },
-
-     "llama2-13b-pure": {
-         "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf"
-     },
-
-     # === Meta LLaMa-2 Chat Models ===
-     "llama2-7b-chat": {
-         "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf"
-     },
-
-     "llama2-13b-chat": {
-         "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf"
-     },
-
-     # === Vicuna v1.5 Chat Models ===
-     "vicuna-v15-7b": {
-         "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5"
-     },
-
-     "vicuna-v15-13b": {
-         "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5"
-     },
- }
- # fmt: on
-
-
- class LLaMa2LLMBackbone(HFCausalLLMBackbone):
-     def __init__(
-         self,
-         llm_backbone_id: str,
-         llm_max_length: int = 2048,
-         hf_token: Optional[str] = None,
-         inference_mode: bool = False,
-         use_flash_attention_2: bool = True,
-     ) -> None:
-         super().__init__(
-             llm_backbone_id,
-             llm_max_length=llm_max_length,
-             hf_token=hf_token,
-             inference_mode=inference_mode,
-             use_flash_attention_2=use_flash_attention_2,
-             **LLAMA2_MODELS[llm_backbone_id],
-         )
-
-         # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize)
-         self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-         self.llm.config.pad_token_id = self.tokenizer.pad_token_id
-         self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
-
-     @property
-     def prompt_builder_fn(self) -> Type[PromptBuilder]:
-         if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"):
-             return PurePromptBuilder
-
-         elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"):
-             return LLaMa2ChatPromptBuilder
-
-         elif self.identifier.startswith("vicuna"):
-             return VicunaV15ChatPromptBuilder
-
-         raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
-
-     @property
-     def transformer_layer_cls(self) -> Type[nn.Module]:
-         return LlamaDecoderLayer
-
-     @property
-     def half_precision_dtype(self) -> torch.dtype:
-         """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2."""
-         return torch.bfloat16
-
-     @property
-     def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
-         return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
capvector-oft/prismatic/models/backbones/llm/mistral.py DELETED
@@ -1,72 +0,0 @@
- """
- mistral.py
-
- Class definition for all LLMs derived from MistralForCausalLM.
- """
-
- from typing import Optional, Type
-
- import torch
- from torch import nn as nn
- from transformers import MistralForCausalLM
- from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
-
- from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
- from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder
-
- # Registry =>> Support Mistral Models (from HF Transformers)
- # fmt: off
- MISTRAL_MODELS = {
-     # === Base Mistral v0.1 ===
-     "mistral-v0.1-7b-pure": {
-         "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
-     },
-
-     # === Mistral Instruct v0.1 ===
-     "mistral-v0.1-7b-instruct": {
-         "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
-     }
- }
- # fmt: on
-
-
- class MistralLLMBackbone(HFCausalLLMBackbone):
-     def __init__(
-         self,
-         llm_backbone_id: str,
-         llm_max_length: int = 2048,
-         hf_token: Optional[str] = None,
-         inference_mode: bool = False,
-         use_flash_attention_2: bool = True,
-     ) -> None:
-         super().__init__(
-             llm_backbone_id,
-             llm_max_length=llm_max_length,
-             hf_token=hf_token,
-             inference_mode=inference_mode,
-             use_flash_attention_2=use_flash_attention_2,
-             **MISTRAL_MODELS[llm_backbone_id],
-         )
-
-         # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
-         self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-         self.llm.config.pad_token_id = self.tokenizer.pad_token_id
-         self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
-
-     @property
-     def prompt_builder_fn(self) -> Type[PromptBuilder]:
-         if self.identifier.endswith("-pure"):
-             return PurePromptBuilder
-
-         elif self.identifier.endswith("-instruct"):
-             return MistralInstructPromptBuilder
-
-         raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
-
-     @property
-     def transformer_layer_cls(self) -> Type[nn.Module]:
-         return MistralDecoderLayer
-
-     @property
-     def half_precision_dtype(self) -> torch.dtype:
-         return torch.bfloat16
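
Taken together with `llama2.py`, the registries above support simple id-based dispatch; a hedged sketch follows (the helper name is hypothetical and not part of this repo):

def build_llm_backbone(llm_backbone_id: str, **kwargs):
    # Dispatch on the registry keys defined in llama2.py / mistral.py
    if llm_backbone_id in LLAMA2_MODELS:
        return LLaMa2LLMBackbone(llm_backbone_id, **kwargs)
    if llm_backbone_id in MISTRAL_MODELS:
        return MistralLLMBackbone(llm_backbone_id, **kwargs)
    raise ValueError(f"Unknown LLM backbone `{llm_backbone_id}`")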