lfh commited on
Commit ·
eb868a1
1
Parent(s): 95a3948
remove files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +0 -35
- README.md +0 -57
- capvector-oft/.pre-commit-config.yaml +0 -27
- capvector-oft/ALOHA.md +0 -157
- capvector-oft/LIBERO.md +0 -130
- capvector-oft/LICENSE +0 -21
- capvector-oft/SETUP.md +0 -24
- capvector-oft/capvector/.gitignore +0 -8
- capvector-oft/capvector/compute_lora_diff.py +0 -35
- capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh +0 -8
- capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh +0 -26
- capvector-oft/capvector/interpolate.py +0 -247
- capvector-oft/capvector/interpolate.sh +0 -26
- capvector-oft/capvector/interpolate_robotwin.py +0 -247
- capvector-oft/capvector/tools/check_model_config.py +0 -23
- capvector-oft/capvector/tools/compute_lora_diff.py +0 -36
- capvector-oft/capvector/tools/compute_lora_diff.sh +0 -8
- capvector-oft/capvector/tools/vector_analyze.py +0 -153
- capvector-oft/capvector/tools/vector_regularize.py +0 -75
- capvector-oft/experiments/robot/aloha/aloha_utils.py +0 -85
- capvector-oft/experiments/robot/aloha/constants.py +0 -100
- capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py +0 -260
- capvector-oft/experiments/robot/aloha/real_env.py +0 -213
- capvector-oft/experiments/robot/aloha/requirements_aloha.txt +0 -26
- capvector-oft/experiments/robot/aloha/robot_utils.py +0 -187
- capvector-oft/experiments/robot/aloha/run_aloha_eval.py +0 -385
- capvector-oft/experiments/robot/libero/libero_requirements.txt +0 -6
- capvector-oft/experiments/robot/libero/libero_utils.py +0 -87
- capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py +0 -249
- capvector-oft/experiments/robot/libero/run_libero_eval.py +0 -540
- capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl +0 -3
- capvector-oft/experiments/robot/openvla_utils.py +0 -818
- capvector-oft/experiments/robot/robot_utils.py +0 -199
- capvector-oft/prismatic/__init__.py +0 -1
- capvector-oft/prismatic/conf/__init__.py +0 -3
- capvector-oft/prismatic/conf/datasets.py +0 -133
- capvector-oft/prismatic/conf/models.py +0 -584
- capvector-oft/prismatic/conf/vla.py +0 -235
- capvector-oft/prismatic/extern/__init__.py +0 -0
- capvector-oft/prismatic/extern/hf/__init__.py +0 -0
- capvector-oft/prismatic/extern/hf/configuration_prismatic.py +0 -140
- capvector-oft/prismatic/extern/hf/modeling_prismatic.py +0 -1085
- capvector-oft/prismatic/extern/hf/processing_prismatic.py +0 -252
- capvector-oft/prismatic/models/__init__.py +0 -2
- capvector-oft/prismatic/models/action_heads.py +0 -211
- capvector-oft/prismatic/models/backbones/__init__.py +0 -0
- capvector-oft/prismatic/models/backbones/llm/__init__.py +0 -4
- capvector-oft/prismatic/models/backbones/llm/base_llm.py +0 -223
- capvector-oft/prismatic/models/backbones/llm/llama2.py +0 -102
- capvector-oft/prismatic/models/backbones/llm/mistral.py +0 -72
.gitattributes
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
DELETED
|
@@ -1,57 +0,0 @@
|
|
| 1 |
-
# CapVector: Learning Transferable Capability Vectors in Parametric Space for Vision-Language-Action Models
|
| 2 |
-
|
| 3 |
-
<div align="center">
|
| 4 |
-
|
| 5 |
-
[](http://arxiv.org/abs/) [](https://capvector.github.io/) [](https://huggingface.co/haofuly/capvector_models_collection)
|
| 6 |
-
|
| 7 |
-
</div>
|
| 8 |
-
|
| 9 |
-
CapVector is a training recipe for vision-language-action (VLA) models that extracts a transferable capability vector from the parameter difference between auxiliary-objective SFT methods and standard SFT methods. This vector is merged into a pretrained VLA to form a stronger initialization, and downstream adaptation uses standard SFT with a lightweight orthogonal regularization loss to preserve the injected capability.
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
## 🌟 Key Features
|
| 13 |
-
- **Efficient downstream adaptation**: CapVector recovers much of the benefit of auxiliary-objective SFT methods, while keeping the downstream overhead close to standard SFT.
|
| 14 |
-
- **Versatility**: CapVector fits for OpenVLA-based, OpenPi-based, and StarVLA-based backbones.
|
| 15 |
-
- **Generalization**: CapVector is designed to transfer across tasks, environments, and robot embodiments.
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
## 🚀 Get Started
|
| 19 |
-
|
| 20 |
-
This repository provides two implementation paths:
|
| 21 |
-
- [`capvector-oft/`](./capvector-oft) based implementation
|
| 22 |
-
- [`capvector-pi05/`](./capvector-pi05) based implementation.
|
| 23 |
-
|
| 24 |
-
Choose the subdirectory that matches your base model and training stack. Follow the subproject README for environment setup, data preparation, training, and inference.
|
| 25 |
-
|
| 26 |
-
[`capvector-pi05/`](./capvector-pi05) provides the capability vector extraction and merging scripts.
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
## 🌏 Contact
|
| 30 |
-
For further discussion and collaboration, please feel free to contact us via Email and WeChat:
|
| 31 |
-
|
| 32 |
-
| Author | Email | WeChat |
|
| 33 |
-
|:---:|:---:|:---:|
|
| 34 |
-
| Wenxuan Song | songwenxuan0115@gmail.com | swx0757 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
## ❤️ Acknowledgments
|
| 38 |
-
|
| 39 |
-
CapVector builds on and interfaces with several excellent open-source projects, including:
|
| 40 |
-
|
| 41 |
-
- [OpenVLA-OFT](https://github.com/moojink/openvla-oft)
|
| 42 |
-
- [OpenPI](https://github.com/Physical-Intelligence/openpi)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
## 🖊 Citation
|
| 46 |
-
|
| 47 |
-
If you find this work useful, please cite:
|
| 48 |
-
|
| 49 |
-
```bibtex
|
| 50 |
-
@article{song2026capvector,
|
| 51 |
-
title = {CapVector: Learning Transferable Capability Vectors in Parametric Space for Vision-Language-Action Models},
|
| 52 |
-
author = {Song, Wenxuan and Zhao, Han and Li, Fuhao and Zhou, Ziyang and Wang, Xi and Lyu, Jing and Ding, Pengxiang and Wang, Yan and Wang, Donglin and Li, Haoang},
|
| 53 |
-
journal = {Preprint},
|
| 54 |
-
year = {2026}
|
| 55 |
-
}
|
| 56 |
-
```
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/.pre-commit-config.yaml
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
# See https://pre-commit.com for more information
|
| 2 |
-
# See https://pre-commit.com/hooks.html for more hooks
|
| 3 |
-
exclude: ".git"
|
| 4 |
-
|
| 5 |
-
repos:
|
| 6 |
-
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 7 |
-
rev: v0.2.2
|
| 8 |
-
hooks:
|
| 9 |
-
- id: ruff
|
| 10 |
-
args: [ --fix, --exit-non-zero-on-fix ]
|
| 11 |
-
|
| 12 |
-
- repo: https://github.com/psf/black
|
| 13 |
-
rev: 24.2.0
|
| 14 |
-
hooks:
|
| 15 |
-
- id: black
|
| 16 |
-
|
| 17 |
-
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 18 |
-
rev: v4.5.0
|
| 19 |
-
hooks:
|
| 20 |
-
- id: check-added-large-files
|
| 21 |
-
- id: check-ast
|
| 22 |
-
- id: check-case-conflict
|
| 23 |
-
- id: check-merge-conflict
|
| 24 |
-
- id: check-toml
|
| 25 |
-
- id: check-yaml
|
| 26 |
-
- id: end-of-file-fixer
|
| 27 |
-
- id: trailing-whitespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/ALOHA.md
DELETED
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
# OpenVLA-OFT+ in Real-World ALOHA Robot Tasks
|
| 2 |
-
|
| 3 |
-
## Relevant Files
|
| 4 |
-
|
| 5 |
-
Evaluation
|
| 6 |
-
* `experiments/robot/aloha/`: ALOHA training and eval files
|
| 7 |
-
* `run_aloha_eval.py`: ALOHA eval script (CLIENT SIDE; see "SERVER SIDE" below)
|
| 8 |
-
* `aloha_utils.py`: ALOHA eval utils
|
| 9 |
-
* Other ALOHA robot environment files copied from the original [ALOHA GitHub repo](https://github.com/tonyzhaozh/aloha):
|
| 10 |
-
* `constants.py`
|
| 11 |
-
* `real_env.py`
|
| 12 |
-
* `robot_utils.py`
|
| 13 |
-
* `experiments/robot/`: General eval utils files
|
| 14 |
-
* `openvla_utils.py`: OpenVLA-specific eval utils
|
| 15 |
-
* `robot_utils.py`: Other eval utils
|
| 16 |
-
* `vla-scripts/deploy.py`: VLA server deploy script (SERVER SIDE)
|
| 17 |
-
|
| 18 |
-
Note: Unlike the LIBERO evaluation setup, we use a server-client interface here. This is particularly useful if the user's machine which commands the robot does not have access to a local GPU with sufficient specs to run the fine-tuned VLA policies.
|
| 19 |
-
|
| 20 |
-
Training
|
| 21 |
-
* `experiments/robot/aloha/`: ALOHA training and eval files
|
| 22 |
-
* `preprocess_split_aloha_data.py`: ALOHA data preprocessing script
|
| 23 |
-
* `vla-scripts/finetune.py`: VLA fine-tuning script
|
| 24 |
-
|
| 25 |
-
## Setup
|
| 26 |
-
|
| 27 |
-
Set up a conda environment for training policies and deploying them on the VLA server (see instructions in [SETUP.md](SETUP.md)).
|
| 28 |
-
|
| 29 |
-
## Fine-Tuning on ALOHA Robot Data
|
| 30 |
-
|
| 31 |
-
We assume that you have collected a set of expert demonstrations on the ALOHA robot already.
|
| 32 |
-
|
| 33 |
-
First, use our `preprocess_split_aloha_data.py` script to preprocess the raw ALOHA dataset: downsize images from 480x640 to 256x256 and split into training and validation sets. Below are examples for the `put X into pot` task in our paper (which has 3 possible target objects, 1 per episode):
|
| 34 |
-
|
| 35 |
-
```bash
|
| 36 |
-
python experiments/robot/aloha/preprocess_split_aloha_data.py \
|
| 37 |
-
--dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \
|
| 38 |
-
--out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
|
| 39 |
-
--percent_val 0.05
|
| 40 |
-
python experiments/robot/aloha/preprocess_split_aloha_data.py \
|
| 41 |
-
--dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \
|
| 42 |
-
--out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
|
| 43 |
-
--percent_val 0.05
|
| 44 |
-
python experiments/robot/aloha/preprocess_split_aloha_data.py \
|
| 45 |
-
--dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \
|
| 46 |
-
--out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
|
| 47 |
-
--percent_val 0.05
|
| 48 |
-
```
|
| 49 |
-
|
| 50 |
-
Then, convert the preprocessed ALOHA datasets into a single RLDS dataset that is compatible with OpenVLA fine-tuning. This process is the same as in the original OpenVLA repo. See instructions for converting to RLDS [here](https://github.com/moojink/rlds_dataset_builder) (a sample ALOHA preprocessed-to-RLDS conversion script is available [here](https://github.com/moojink/rlds_dataset_builder/blob/main/aloha1_put_X_into_pot_300_demos/aloha1_put_X_into_pot_300_demos_dataset_builder.py); this script converts the three preprocessed datasets above into one unified RLDS dataset, with train/val splits).
|
| 51 |
-
|
| 52 |
-
After converting to RLDS, register the dataset (which, for the example task above, would be called `aloha1_put_X_into_pot_300_demos`) with our dataloader by adding an entry for it in `configs.py` ([here](prismatic/vla/datasets/rlds/oxe/configs.py#L680)), `transforms.py` ([here](prismatic/vla/datasets/rlds/oxe/transforms.py#L928)), and `mixtures.py` ([here](prismatic/vla/datasets/rlds/oxe/mixtures.py#L216)). For reference, in each of these files, there are sample entries for the ALOHA datasets that we used in our paper.
|
| 53 |
-
|
| 54 |
-
Before fine-tuning, set the desired ALOHA action chunk size in [`prismatic/vla/constants.py`](prismatic/vla/constants.py) (see `NUM_ACTIONS_CHUNK` in `ALOHA_CONSTANTS`). We set it to 25 by default because we used a control frequency of 25 Hz in our ALOHA setup to reduce storage costs and training time (while still maintaining smoothness in the robot's motions). If you use 50 Hz, we recommend setting `NUM_ACTIONS_CHUNK` to `50`. In general, 1 second-long action chunks are a good default. Do NOT modify `ACTION_PROPRIO_NORMALIZATION_TYPE`: Since the ALOHA robot action space is absolute joint angles, we do not want to use a normalization scheme that clips outlier values (like the Q1-Q99 normalization we used with the relative end-effector pose actions for LIBERO), since that would prevent the model from outputting certain robot joint angles that are crucial for solving the task.
|
| 55 |
-
|
| 56 |
-
Now begin fine-tuning! Below is a sample command to fine-tune OpenVLA using our OFT+ recipe on the `put X into pot` task above ("+" in "OFT+" means FiLM is included for enhanced language grounding). Replace `X` in the first line with the number of GPUs available to you.
|
| 57 |
-
|
| 58 |
-
```bash
|
| 59 |
-
torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \
|
| 60 |
-
--vla_path openvla/openvla-7b \
|
| 61 |
-
--data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \
|
| 62 |
-
--dataset_name aloha1_put_X_into_pot_300_demos \
|
| 63 |
-
--run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \
|
| 64 |
-
--use_l1_regression True \
|
| 65 |
-
--use_diffusion False \
|
| 66 |
-
--use_film True \
|
| 67 |
-
--num_images_in_input 3 \
|
| 68 |
-
--use_proprio True \
|
| 69 |
-
--batch_size 4 \
|
| 70 |
-
--learning_rate 5e-4 \
|
| 71 |
-
--num_steps_before_decay 50000 \
|
| 72 |
-
--max_steps 100005 \
|
| 73 |
-
--use_val_set True \
|
| 74 |
-
--val_freq 10000 \
|
| 75 |
-
--save_freq 10000 \
|
| 76 |
-
--save_latest_checkpoint_only False \
|
| 77 |
-
--image_aug True \
|
| 78 |
-
--lora_rank 32 \
|
| 79 |
-
--wandb_entity "YOUR_WANDB_ENTITY" \
|
| 80 |
-
--wandb_project "YOUR_WANDB_PROJECT" \
|
| 81 |
-
--run_id_note parallel_dec--25_acts_chunk--continuous_acts--L1_regression--3rd_person_img--left_right_wrist_imgs--proprio_state--film
|
| 82 |
-
```
|
| 83 |
-
|
| 84 |
-
The above training command should reproduce our OpenVLA-OFT+ results on the `put X into pot` task if `X = 8` and the 100K step checkpoint is evaluated. It will fine-tune OpenVLA using 3 input images (1 third-person image + 2 wrist camera images). Note that we use learning rate decay after a certain point (50K steps in the command above) since doing so speeds up training convergence (train L1 loss spikes down from our experience).
|
| 85 |
-
|
| 86 |
-
Best practices for fine-tuning:
|
| 87 |
-
* In general, we recommend fine-tuning until training L1 loss goes below 0.01 and starts to plateau.
|
| 88 |
-
* One way to achieve this is to fine-tune using our default learning rate of `5e-4` until the loss starts to decrease very slowly, and then decay the learning rate by 10x to `5e-5` (which should make the loss spike down) and train until the training L1 loss finally plateaus.
|
| 89 |
-
* Depending on your dataset size, you may need to adjust some hyperparameters. For example, if you use a large dataset with over 300 demos, you may need to decay the learning rate later and train for longer for best performance. Decaying too earlier can lead to a suboptimal policy.
|
| 90 |
-
* If your task does not require good langauge grounding (e.g., if there is only one language instruction), FiLM is not necessary; consider setting `--use_film False` to train fewer model parameters.
|
| 91 |
-
* Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you).
|
| 92 |
-
|
| 93 |
-
If you run into any issues, please open a new GitHub issue.
|
| 94 |
-
|
| 95 |
-
## Launching ALOHA Robot Evaluations
|
| 96 |
-
|
| 97 |
-
In the primary conda environment (`openvla-oft`) which you will use to launch the VLA server, install a few packages for the server-client interface:
|
| 98 |
-
|
| 99 |
-
```bash
|
| 100 |
-
conda activate openvla-oft
|
| 101 |
-
pip install uvicorn fastapi json-numpy
|
| 102 |
-
```
|
| 103 |
-
|
| 104 |
-
On the machine that you will use to command the robot, set up a second conda environment that will be used to run the robot environment, query the VLA server, and execute actions in the environment:
|
| 105 |
-
|
| 106 |
-
```bash
|
| 107 |
-
# Create and activate client conda environment
|
| 108 |
-
conda create -n openvla-oft-aloha python=3.10 -y
|
| 109 |
-
conda activate openvla-oft-aloha
|
| 110 |
-
|
| 111 |
-
# Install PyTorch
|
| 112 |
-
# Use a command specific to your machine: https://pytorch.org/get-started/locally/
|
| 113 |
-
pip3 install torch torchvision torchaudio
|
| 114 |
-
|
| 115 |
-
# Clone openvla-oft repo and pip install to download dependencies
|
| 116 |
-
git clone https://github.com/moojink/openvla-oft.git
|
| 117 |
-
cd openvla-oft
|
| 118 |
-
pip install -e .
|
| 119 |
-
|
| 120 |
-
# Install packages needed for the ALOHA robot environment
|
| 121 |
-
pip install -r experiments/robot/aloha/requirements_aloha.txt
|
| 122 |
-
```
|
| 123 |
-
|
| 124 |
-
Launch the VLA server on the machine that has the GPU you will use to run model inference (using the `openvla-oft` conda environment). Below is a sample command for this (change as needed):
|
| 125 |
-
|
| 126 |
-
```bash
|
| 127 |
-
python vla-scripts/deploy.py \
|
| 128 |
-
--pretrained_checkpoint /PATH/TO/FINETUNED/MODEL/CHECKPOINT/DIR/ \
|
| 129 |
-
--use_l1_regression True \
|
| 130 |
-
--use_film True \
|
| 131 |
-
--num_images_in_input 3 \
|
| 132 |
-
--use_proprio True \
|
| 133 |
-
--center_crop True \
|
| 134 |
-
--unnorm_key aloha1_put_X_into_pot_300_demos
|
| 135 |
-
```
|
| 136 |
-
|
| 137 |
-
Then, run the ALOHA evaluation script. Specify the VLA server URL or IP address in the `vla_server_url` argument. Below is a sample command:
|
| 138 |
-
|
| 139 |
-
```bash
|
| 140 |
-
python experiments/robot/aloha/run_aloha_eval.py \
|
| 141 |
-
--center_crop True \
|
| 142 |
-
--num_open_loop_steps 25 \
|
| 143 |
-
--use_vla_server True \
|
| 144 |
-
--vla_server_url <URL OF VLA SERVER> \
|
| 145 |
-
--num_rollouts_planned <NUM TEST ROLLOUTS> \
|
| 146 |
-
--max_steps <MAX NUM STEPS PER ROLLOUT>
|
| 147 |
-
```
|
| 148 |
-
|
| 149 |
-
If you run into any issues, please open a new GitHub issue.
|
| 150 |
-
|
| 151 |
-
## Troubleshooting Tips
|
| 152 |
-
|
| 153 |
-
* Tip #1: If you run into a ROS error such as `ImportError: /lib/x86_64-linux-gnu/libp11-kit.so.0: undefined symbol: ffi_type_pointer, version LIBFFI_BASE_7.0`, try running the following command in your client conda environment (`openvla-oft-aloha`):
|
| 154 |
-
|
| 155 |
-
```
|
| 156 |
-
conda install -c conda-forge libffi
|
| 157 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/LIBERO.md
DELETED
|
@@ -1,130 +0,0 @@
|
|
| 1 |
-
# OpenVLA-OFT in the LIBERO Simulation Benchmark
|
| 2 |
-
|
| 3 |
-
## Relevant Files
|
| 4 |
-
|
| 5 |
-
Evaluation
|
| 6 |
-
* `experiments/robot/libero/`: LIBERO eval files
|
| 7 |
-
* `run_libero_eval.py`: LIBERO eval script
|
| 8 |
-
* `libero_utils.py`: LIBERO eval utils
|
| 9 |
-
* `experiments/robot/`: General eval utils files
|
| 10 |
-
* `openvla_utils.py`: OpenVLA-specific eval utils
|
| 11 |
-
* `robot_utils.py`: Other eval utils
|
| 12 |
-
|
| 13 |
-
Training
|
| 14 |
-
* `vla-scripts/finetune.py`: VLA fine-tuning script
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
## Setup
|
| 18 |
-
|
| 19 |
-
Set up a conda environment (see instructions in [SETUP.md](SETUP.md)).
|
| 20 |
-
|
| 21 |
-
Clone and install the [LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) and required packages:
|
| 22 |
-
|
| 23 |
-
```bash
|
| 24 |
-
git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git
|
| 25 |
-
pip install -e LIBERO
|
| 26 |
-
pip install -r experiments/robot/libero/libero_requirements.txt # From openvla-oft base dir
|
| 27 |
-
```
|
| 28 |
-
|
| 29 |
-
(Optional, if you plan to launch training) To download the [LIBERO datasets](https://huggingface.co/datasets/openvla/modified_libero_rlds) that we used in our fine-tuning
|
| 30 |
-
experiments, run the command below. This will download the LIBERO-Spatial, LIBERO-Object, LIBERO-Goal,
|
| 31 |
-
and LIBERO-10 datasets in RLDS data format (~10 GB total). You can use these to fine-tune OpenVLA or
|
| 32 |
-
train other methods. This step is optional since we provide pretrained OpenVLA-OFT checkpoints below.
|
| 33 |
-
Note that these are the same datasets used in the original OpenVLA project. If needed, see details on how to download the original non-RLDS datasets [here](https://github.com/openvla/openvla?tab=readme-ov-file#libero-setup).
|
| 34 |
-
```bash
|
| 35 |
-
git clone git@hf.co:datasets/openvla/modified_libero_rlds
|
| 36 |
-
```
|
| 37 |
-
|
| 38 |
-
## Launching LIBERO Evaluations
|
| 39 |
-
|
| 40 |
-
We fine-tuned OpenVLA via LoRA (r=32) with our OFT recipe on four LIBERO task suites: LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, and LIBERO-10 (also called LIBERO-Long).
|
| 41 |
-
In the initial version of our paper, we trained one checkpoint for each LIBERO task suite independently. In an updated version of the paper, we conducted an additional experiment in which we trained a single policy on all four task suites combined (results for this are available in the Additional Experiments section in the Appendix). Overall, the results for the task-specific policies and the combined policy are comparable: 97.1% vs. 96.8% average success rate across the four suites, respectively.
|
| 42 |
-
|
| 43 |
-
Below are the four independently trained OpenVLA-OFT checkpoints for LIBERO:
|
| 44 |
-
* [moojink/openvla-7b-oft-finetuned-libero-spatial](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial)
|
| 45 |
-
* [moojink/openvla-7b-oft-finetuned-libero-object](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-object)
|
| 46 |
-
* [moojink/openvla-7b-oft-finetuned-libero-goal](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-goal)
|
| 47 |
-
* [moojink/openvla-7b-oft-finetuned-libero-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-10)
|
| 48 |
-
|
| 49 |
-
Below is the OpenVLA-OFT checkpoint trained on all four task suites combined:
|
| 50 |
-
* [moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10)
|
| 51 |
-
|
| 52 |
-
To start evaluations with one of the independently trained checkpoints, run one of the commands below. Each will automatically download the appropriate checkpoint listed above. You can set the `TRANSFORMERS_CACHE` and `HF_HOME` environment variable to change where the checkpoint files get cached.
|
| 53 |
-
|
| 54 |
-
```bash
|
| 55 |
-
# Launch LIBERO-Spatial evals
|
| 56 |
-
python experiments/robot/libero/run_libero_eval.py \
|
| 57 |
-
--pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-spatial \
|
| 58 |
-
--task_suite_name libero_spatial
|
| 59 |
-
|
| 60 |
-
# Launch LIBERO-Object evals
|
| 61 |
-
python experiments/robot/libero/run_libero_eval.py \
|
| 62 |
-
--pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-object \
|
| 63 |
-
--task_suite_name libero_object
|
| 64 |
-
|
| 65 |
-
# Launch LIBERO-Goal evals
|
| 66 |
-
python experiments/robot/libero/run_libero_eval.py \
|
| 67 |
-
--pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-goal \
|
| 68 |
-
--task_suite_name libero_goal
|
| 69 |
-
|
| 70 |
-
# Launch LIBERO-10 (LIBERO-Long) evals
|
| 71 |
-
python experiments/robot/libero/run_libero_eval.py \
|
| 72 |
-
--pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-10 \
|
| 73 |
-
--task_suite_name libero_10
|
| 74 |
-
```
|
| 75 |
-
|
| 76 |
-
To evaluate the policy trained on all four task suites together, simply swap out the `--pretrained_checkpoint` in the commands above with `moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10`.
|
| 77 |
-
|
| 78 |
-
Notes:
|
| 79 |
-
* The evaluation script will run 500 trials by default (10 tasks x 50 episodes each). You can modify the number of
|
| 80 |
-
trials per task by setting `--num_trials_per_task`. You can also change the random seed via `--seed`. There are
|
| 81 |
-
other arguments in the script; we set them to the default values that work with the OpenVLA-OFT checkpoints above.
|
| 82 |
-
* **NOTE: Setting `--center_crop True` is important** because we fine-tuned OpenVLA with random crop augmentations
|
| 83 |
-
(we took a random crop with 90% area in every training sample, so at test time we simply take the center 90% crop).
|
| 84 |
-
* The evaluation script logs results locally. You can also log results in Weights & Biases
|
| 85 |
-
by setting `--use_wandb True` and specifying `--wandb_project <PROJECT>` and `--wandb_entity <ENTITY>`.
|
| 86 |
-
* The results reported in our paper were obtained using **Python 3.10.14, PyTorch 2.2.0, and our
|
| 87 |
-
[custom transformers v4.40.1 fork](https://github.com/moojink/transformers-openvla-oft.git)**
|
| 88 |
-
on an **NVIDIA A100 GPU**, averaged over three random seeds. Please stick to these package versions if possible.
|
| 89 |
-
Note that results may vary slightly if you use a different GPU than the A100. If the discrepancy is large,
|
| 90 |
-
please post a GitHub issue, and we will look into it.
|
| 91 |
-
|
| 92 |
-
## Fine-Tuning on LIBERO Datasets
|
| 93 |
-
|
| 94 |
-
First, download the LIBERO datasets as mentioned above in the Setup section above: `libero_spatial_no_noops`, `libero_object_no_noops`, `libero_goal_no_noops`, `libero_10_no_noops`. (`"_no_noops"` stands for no no-op actions, i.e., training samples with near-zero actions are filtered out).
|
| 95 |
-
|
| 96 |
-
Then, launch the fine-tuning script with the OFT configuration below, replacing `X` in the first line with the number of GPUs. The command below launches fine-tuning on LIBERO-Spatial with the hyperparameters that we used in our paper. Here, batch size 8 per GPU will require ~62 GB VRAM, and batch size 1 per GPU will require ~25 GB VRAM.
|
| 97 |
-
|
| 98 |
-
```bash
|
| 99 |
-
torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \
|
| 100 |
-
--vla_path openvla/openvla-7b \
|
| 101 |
-
--data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \
|
| 102 |
-
--dataset_name libero_spatial_no_noops \
|
| 103 |
-
--run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \
|
| 104 |
-
--use_l1_regression True \
|
| 105 |
-
--use_diffusion False \
|
| 106 |
-
--use_film False \
|
| 107 |
-
--num_images_in_input 2 \
|
| 108 |
-
--use_proprio True \
|
| 109 |
-
--batch_size 8 \
|
| 110 |
-
--learning_rate 5e-4 \
|
| 111 |
-
--num_steps_before_decay 100000 \
|
| 112 |
-
--max_steps 150005 \
|
| 113 |
-
--save_freq 10000 \
|
| 114 |
-
--save_latest_checkpoint_only False \
|
| 115 |
-
--image_aug True \
|
| 116 |
-
--lora_rank 32 \
|
| 117 |
-
--wandb_entity "YOUR_WANDB_ENTITY" \
|
| 118 |
-
--wandb_project "YOUR_WANDB_PROJECT" \
|
| 119 |
-
--run_id_note parallel_dec--8_acts_chunk--continuous_acts--L1_regression--3rd_person_img--wrist_img--proprio_state
|
| 120 |
-
```
|
| 121 |
-
|
| 122 |
-
The above training command should reproduce our OpenVLA-OFT results if `X = 8` and the 150K step checkpoint is evaluated.
|
| 123 |
-
|
| 124 |
-
You can replace `libero_spatial_no_noops` with `libero_object_no_noops`, `libero_goal_no_noops`, or `libero_10_no_noops`. You can also modify other args — e.g., if you want to train with just one input image from the third-person camera and disable proprio state input, you can set `--num_images_in_input 1` and `--use_proprio False`.
|
| 125 |
-
|
| 126 |
-
In general, we recommend fine-tuning until training L1 loss goes below 0.01 and starts to plateau (with the above configuration, it should reach ~0.006 L1 loss on LIBERO-Spatial after 150K gradient steps with 10x LR decay after 100K steps). However, for LIBERO-Goal only, we found that the 50K checkpoint (which was at ~0.02 L1 loss) performed best for unknown reasons. For all other task suites though, we found that the 150K checkpoint performed best.
|
| 127 |
-
|
| 128 |
-
Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you).
|
| 129 |
-
|
| 130 |
-
If you run into any issues, please open a new GitHub issue. If you do not receive a response within 2 business days, please email Moo Jin Kim (moojink@cs.stanford.edu) to bring the issue to his attention.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/LICENSE
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
MIT License
|
| 2 |
-
|
| 3 |
-
Copyright (c) 2025 Moo Jin Kim, Chelsea Finn, Percy Liang.
|
| 4 |
-
|
| 5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
-
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
-
in the Software without restriction, including without limitation the rights
|
| 8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
-
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
-
furnished to do so, subject to the following conditions:
|
| 11 |
-
|
| 12 |
-
The above copyright notice and this permission notice shall be included in all
|
| 13 |
-
copies or substantial portions of the Software.
|
| 14 |
-
|
| 15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/SETUP.md
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
# Setup Instructions
|
| 2 |
-
|
| 3 |
-
## Set Up Conda Environment
|
| 4 |
-
|
| 5 |
-
```bash
|
| 6 |
-
# Create and activate conda environment
|
| 7 |
-
conda create -n capvector-openvla-oft python=3.10 -y
|
| 8 |
-
conda activate capvector-openvla-oft
|
| 9 |
-
|
| 10 |
-
# Install PyTorch
|
| 11 |
-
# Use a command specific to your machine: https://pytorch.org/get-started/locally/
|
| 12 |
-
pip3 install torch torchvision torchaudio
|
| 13 |
-
|
| 14 |
-
# Clone openvla-oft repo and pip install to download dependencies
|
| 15 |
-
git clone https://github.com/Songwxuan/CapVector
|
| 16 |
-
cd openvla-oft
|
| 17 |
-
pip install -e .
|
| 18 |
-
|
| 19 |
-
# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
|
| 20 |
-
# =>> If you run into difficulty, try `pip cache remove flash_attn` first
|
| 21 |
-
pip install packaging ninja
|
| 22 |
-
ninja --version; echo $? # Verify Ninja --> should return exit code "0"
|
| 23 |
-
pip install "flash-attn==2.5.5" --no-build-isolation
|
| 24 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/.gitignore
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
bin/
|
| 2 |
-
draw_pic/
|
| 3 |
-
feature_vector_ckpt/
|
| 4 |
-
figure/
|
| 5 |
-
id_extrapolation/
|
| 6 |
-
id_interpolation/
|
| 7 |
-
initialized_pt_vla/
|
| 8 |
-
lora_diff/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/compute_lora_diff.py
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
from safetensors.torch import load_file, save_file
|
| 2 |
-
import torch
|
| 3 |
-
import argparse
|
| 4 |
-
|
| 5 |
-
def main():
|
| 6 |
-
parser = argparse.ArgumentParser()
|
| 7 |
-
parser.add_argument("--base", required=True)
|
| 8 |
-
parser.add_argument("--target", required=True)
|
| 9 |
-
parser.add_argument("--out", default="lora_diff.safetensors")
|
| 10 |
-
args = parser.parse_args()
|
| 11 |
-
|
| 12 |
-
base = load_file(args.base)
|
| 13 |
-
target = load_file(args.target)
|
| 14 |
-
|
| 15 |
-
diff = {}
|
| 16 |
-
|
| 17 |
-
print("=== Key Comparison ===")
|
| 18 |
-
only_in_base = set(base) - set(target)
|
| 19 |
-
only_in_target = set(target) - set(base)
|
| 20 |
-
|
| 21 |
-
print("Only in base:", list(only_in_base)[:10])
|
| 22 |
-
print("Only in target:", list(only_in_target)[:10])
|
| 23 |
-
|
| 24 |
-
for k in target:
|
| 25 |
-
if k in base:
|
| 26 |
-
diff[k] = target[k] - base[k]
|
| 27 |
-
else:
|
| 28 |
-
# new parameters are directly retained
|
| 29 |
-
diff[k] = target[k].clone()
|
| 30 |
-
|
| 31 |
-
save_file(diff, args.out)
|
| 32 |
-
print(f"\nSaved diff to: {args.out}")
|
| 33 |
-
|
| 34 |
-
if __name__ == "__main__":
|
| 35 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
BASE_ADAPTER="checkpoints/reference_models/openvla_oft_libero_spatial/lora_adapter/adapter_model.safetensors"
|
| 2 |
-
TARGET_ADAPTER="checkpoints/task_models/SF_spatial/lora_adapter/adapter_model.safetensors"
|
| 3 |
-
OUTPUT_DIFF="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors"
|
| 4 |
-
|
| 5 |
-
python compute_lora_diff.py \
|
| 6 |
-
--base "$BASE_ADAPTER" \
|
| 7 |
-
--target "$TARGET_ADAPTER" \
|
| 8 |
-
--out "$OUTPUT_DIFF"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
TASK=bigbin_pot_microwave_qrcode_bowlsthree # Customize for your task
|
| 2 |
-
VERSION=53
|
| 3 |
-
PT_CKPT="checkpoints/openvla_base"
|
| 4 |
-
TASK_MODEL_CHECKPOINT="checkpoints/task_models/v106.1"
|
| 5 |
-
REFERENCE_MODEL_CHECKPOINT="checkpoints/reference_models/v106.0"
|
| 6 |
-
VECTOR_SAVE_PATH="checkpoints/feature_vectors/feature_vector_with_SF_${TASK}_v${VERSION}.pth"
|
| 7 |
-
INITIALIZED_PT_VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_${TASK}_v${VERSION}"
|
| 8 |
-
TASK_SUITE_NAME="ALOHA_${TASK}"
|
| 9 |
-
|
| 10 |
-
python interpolate_robotwin.py \
|
| 11 |
-
--pretrained_checkpoint "$TASK_MODEL_CHECKPOINT" \
|
| 12 |
-
--original_pretrained_checkpoint "$REFERENCE_MODEL_CHECKPOINT" \
|
| 13 |
-
--vector_save_path "$VECTOR_SAVE_PATH" \
|
| 14 |
-
--initialized_pt_vla_path $INITIALIZED_PT_VLA_PATH \
|
| 15 |
-
--pt_ckpt $PT_CKPT\
|
| 16 |
-
--feature_vector_weight 1.1 \
|
| 17 |
-
--task_suite_name $TASK_SUITE_NAME
|
| 18 |
-
|
| 19 |
-
#the code below is used to transplant the model parameters except for the vla backbone, such as processor and tokenizer, so that the initialized model is complete
|
| 20 |
-
|
| 21 |
-
rsync -av \
|
| 22 |
-
--ignore-existing \
|
| 23 |
-
--exclude='*.safetensors' \
|
| 24 |
-
--exclude='*.back.*' \
|
| 25 |
-
$PT_CKPT/ \
|
| 26 |
-
$INITIALIZED_PT_VLA_PATH/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/interpolate.py
DELETED
|
@@ -1,247 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import json
|
| 8 |
-
import logging
|
| 9 |
-
|
| 10 |
-
import sys
|
| 11 |
-
from collections import deque
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from enum import Enum
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Optional, Union
|
| 16 |
-
from PIL import Image
|
| 17 |
-
|
| 18 |
-
import draccus
|
| 19 |
-
import numpy as np
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
import torch
|
| 22 |
-
import copy
|
| 23 |
-
|
| 24 |
-
import wandb
|
| 25 |
-
|
| 26 |
-
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 27 |
-
if str(REPO_ROOT) not in sys.path:
|
| 28 |
-
sys.path.append(str(REPO_ROOT))
|
| 29 |
-
from experiments.robot.openvla_utils import (
|
| 30 |
-
get_action_head,
|
| 31 |
-
get_noisy_action_projector,
|
| 32 |
-
get_processor,
|
| 33 |
-
get_proprio_projector,
|
| 34 |
-
resize_image_for_policy,
|
| 35 |
-
)
|
| 36 |
-
from experiments.robot.robot_utils import (
|
| 37 |
-
DATE_TIME,
|
| 38 |
-
get_action,
|
| 39 |
-
get_image_resize_size,
|
| 40 |
-
get_model,
|
| 41 |
-
invert_gripper_action,
|
| 42 |
-
normalize_gripper_action,
|
| 43 |
-
set_seed_everywhere,
|
| 44 |
-
)
|
| 45 |
-
from experiments.robot.libero.run_libero_eval import check_unnorm_key
|
| 46 |
-
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# Set up logging
|
| 50 |
-
logging.basicConfig(
|
| 51 |
-
level=logging.INFO,
|
| 52 |
-
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 53 |
-
handlers=[logging.StreamHandler()],
|
| 54 |
-
)
|
| 55 |
-
logger = logging.getLogger(__name__)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
@dataclass
|
| 59 |
-
class GenerateConfig:
|
| 60 |
-
# fmt: off
|
| 61 |
-
|
| 62 |
-
#################################################################################################################
|
| 63 |
-
# Model-specific parameters
|
| 64 |
-
#################################################################################################################
|
| 65 |
-
model_family: str = "openvla" # Model family
|
| 66 |
-
#the task-specific model after sf fine-tuning
|
| 67 |
-
pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model" # Task-specific checkpoint path
|
| 68 |
-
#the task-specific model after oft fine-tuning
|
| 69 |
-
original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model" # Reference checkpoint path
|
| 70 |
-
#feature vector is the difference between the two models, which represents the spatial features
|
| 71 |
-
vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
|
| 72 |
-
#the pt vla model initialized with the feature vector, named rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero}
|
| 73 |
-
initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
|
| 74 |
-
#the original pretrained openvla model
|
| 75 |
-
pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
|
| 76 |
-
#the weight of the feature vector when initializing the pt vla model
|
| 77 |
-
feature_vector_weight: float = 1 # Weight of feature vector for interpolation
|
| 78 |
-
|
| 79 |
-
use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective
|
| 80 |
-
use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM)
|
| 81 |
-
num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
|
| 82 |
-
num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference
|
| 83 |
-
use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
|
| 84 |
-
num_images_in_input: int = 2 # Number of images in the VLA input (default: 1)
|
| 85 |
-
use_proprio: bool = True # Whether to include proprio state in input
|
| 86 |
-
|
| 87 |
-
center_crop: bool = True # Center crop? (if trained w/ random crop image aug)
|
| 88 |
-
num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy
|
| 89 |
-
|
| 90 |
-
lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
|
| 91 |
-
|
| 92 |
-
unnorm_key: Union[str, Path] = "" # Action un-normalization key
|
| 93 |
-
|
| 94 |
-
load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization
|
| 95 |
-
load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization
|
| 96 |
-
|
| 97 |
-
#################################################################################################################
|
| 98 |
-
# LIBERO environment-specific parameters
|
| 99 |
-
#################################################################################################################
|
| 100 |
-
task_suite_name: str = "de" # Task suite
|
| 101 |
-
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
|
| 102 |
-
num_trials_per_task: int = 50 # Number of rollouts per task
|
| 103 |
-
initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file
|
| 104 |
-
env_img_res: int = 256 # Resolution for environment images (not policy input resolution)
|
| 105 |
-
|
| 106 |
-
#################################################################################################################
|
| 107 |
-
# Utils
|
| 108 |
-
#################################################################################################################
|
| 109 |
-
run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
|
| 110 |
-
local_log_dir: str = "./experiments/logs" # Local directory for eval logs
|
| 111 |
-
|
| 112 |
-
use_wandb: bool = False # Whether to also log results in Weights & Biases
|
| 113 |
-
wandb_entity: str = "your-wandb-entity" # Name of WandB entity
|
| 114 |
-
wandb_project: str = "your-wandb-project" # Name of WandB project
|
| 115 |
-
|
| 116 |
-
seed: int = 7 # Random Seed (for reproducibility)
|
| 117 |
-
|
| 118 |
-
def validate_config(cfg: GenerateConfig) -> None:
|
| 119 |
-
"""Validate configuration parameters."""
|
| 120 |
-
assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"
|
| 121 |
-
|
| 122 |
-
if "image_aug" in str(cfg.pretrained_checkpoint):
|
| 123 |
-
assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
|
| 124 |
-
|
| 125 |
-
assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"
|
| 126 |
-
|
| 127 |
-
# Validate task suite
|
| 128 |
-
assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
|
| 129 |
-
|
| 130 |
-
def initialize_model(cfg: GenerateConfig, only_pt: bool = False): #load action_head and noisy_action_projector separately
|
| 131 |
-
"""Initialize model and associated components."""
|
| 132 |
-
# Load model
|
| 133 |
-
model = get_model(cfg)
|
| 134 |
-
|
| 135 |
-
# Load proprio projector if needed
|
| 136 |
-
proprio_projector = None
|
| 137 |
-
if cfg.use_proprio:
|
| 138 |
-
proprio_projector = get_proprio_projector(
|
| 139 |
-
cfg,
|
| 140 |
-
model.llm_dim,
|
| 141 |
-
proprio_dim=8, # 8-dimensional proprio for LIBERO
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Load action head if needed
|
| 145 |
-
action_head = None
|
| 146 |
-
if cfg.use_l1_regression or cfg.use_diffusion:
|
| 147 |
-
action_head = get_action_head(cfg, model.llm_dim)
|
| 148 |
-
|
| 149 |
-
# Load noisy action projector if using diffusion
|
| 150 |
-
noisy_action_projector = None
|
| 151 |
-
if cfg.use_diffusion:
|
| 152 |
-
noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)
|
| 153 |
-
|
| 154 |
-
# Get OpenVLA processor if needed
|
| 155 |
-
processor = None
|
| 156 |
-
if not only_pt:
|
| 157 |
-
if cfg.model_family == "openvla":
|
| 158 |
-
processor = get_processor(cfg)
|
| 159 |
-
check_unnorm_key(cfg, model)
|
| 160 |
-
|
| 161 |
-
return model, action_head, proprio_projector, noisy_action_projector, processor
|
| 162 |
-
|
| 163 |
-
# @draccus.wrap()
|
| 164 |
-
def generate_feature_vector(cfg: GenerateConfig):
|
| 165 |
-
"""Generate a feature vector (parameter differences) between two task-specific models."""
|
| 166 |
-
# Validate configuration
|
| 167 |
-
|
| 168 |
-
# Set random seed
|
| 169 |
-
set_seed_everywhere(cfg.seed)
|
| 170 |
-
|
| 171 |
-
# Initialize model and components
|
| 172 |
-
model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
|
| 173 |
-
|
| 174 |
-
original_config = GenerateConfig(
|
| 175 |
-
pretrained_checkpoint=cfg.original_pretrained_checkpoint,
|
| 176 |
-
task_suite_name=cfg.task_suite_name,
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
original_model, original_action_head, original_proprio_projector, original_noisy_action_projector, original_processor = initialize_model(original_config)
|
| 180 |
-
#for action_head and noisy_action_projector, these modules are not interpolated
|
| 181 |
-
assert len(model.state_dict()) == len(original_model.state_dict())
|
| 182 |
-
feature_vector_dict = {}
|
| 183 |
-
total = len(original_model.state_dict())
|
| 184 |
-
for name, original_model_param in tqdm(original_model.named_parameters(), total=total):
|
| 185 |
-
model_param = model.state_dict()[name]
|
| 186 |
-
feature_vector_dict[name] = (model_param - original_model_param).detach().cpu()
|
| 187 |
-
|
| 188 |
-
return feature_vector_dict
|
| 189 |
-
|
| 190 |
-
# @draccus.wrap()
|
| 191 |
-
def interpolate_feature_vector(cfg: GenerateConfig):
|
| 192 |
-
"""Interpolate feature vector."""
|
| 193 |
-
feature_vector_dict = torch.load(cfg.vector_save_path)
|
| 194 |
-
|
| 195 |
-
pt_vla_config = GenerateConfig(
|
| 196 |
-
pretrained_checkpoint=cfg.pt_ckpt,
|
| 197 |
-
original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
|
| 198 |
-
vector_save_path=cfg.vector_save_path,
|
| 199 |
-
initialized_pt_vla_path=cfg.initialized_pt_vla_path,
|
| 200 |
-
feature_vector_weight=cfg.feature_vector_weight,
|
| 201 |
-
pt_ckpt=cfg.pt_ckpt,
|
| 202 |
-
task_suite_name=cfg.task_suite_name,
|
| 203 |
-
use_proprio=False,
|
| 204 |
-
use_l1_regression=False,
|
| 205 |
-
use_diffusion=False
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
pt_vla,_,_,_,_ = initialize_model(pt_vla_config, only_pt=True)
|
| 209 |
-
|
| 210 |
-
#copy the SF parameters for checking the change before and after interpolation
|
| 211 |
-
model_sd = pt_vla.state_dict()
|
| 212 |
-
before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}
|
| 213 |
-
|
| 214 |
-
with torch.no_grad():
|
| 215 |
-
pt_params = dict(pt_vla.named_parameters())
|
| 216 |
-
for name, diff in feature_vector_dict.items():
|
| 217 |
-
if name in pt_params:
|
| 218 |
-
pt_param = pt_params[name]
|
| 219 |
-
diff = diff.to(pt_param.device)
|
| 220 |
-
pt_param.add_(diff, alpha=cfg.feature_vector_weight)
|
| 221 |
-
|
| 222 |
-
#check after interpolation
|
| 223 |
-
diffs_after = []
|
| 224 |
-
for name, before_tensor in before_interp_sd.items():
|
| 225 |
-
after_tensor = model_sd[name]
|
| 226 |
-
difference = (after_tensor - before_tensor).float().norm().item()
|
| 227 |
-
diffs_after.append(difference)
|
| 228 |
-
|
| 229 |
-
print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
|
| 230 |
-
f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")
|
| 231 |
-
|
| 232 |
-
#########################################################
|
| 233 |
-
return pt_vla
|
| 234 |
-
|
| 235 |
-
@draccus.wrap()
|
| 236 |
-
def main(cfg: GenerateConfig):
|
| 237 |
-
if not os.path.exists(cfg.vector_save_path):
|
| 238 |
-
feature_vector_dict = generate_feature_vector(cfg)
|
| 239 |
-
torch.save(feature_vector_dict, cfg.vector_save_path)
|
| 240 |
-
else:
|
| 241 |
-
print(f"Feature vector already exists at {cfg.vector_save_path}")
|
| 242 |
-
initialized_pt_vla = interpolate_feature_vector(cfg)
|
| 243 |
-
os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
|
| 244 |
-
initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)
|
| 245 |
-
|
| 246 |
-
if __name__ == "__main__":
|
| 247 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/interpolate.sh
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
TASK=spatial # or object / goal / 10 / 90
|
| 2 |
-
VERSION=21.4
|
| 3 |
-
PT_CKPT="checkpoints/openvla_base"
|
| 4 |
-
TASK_MODEL_CHECKPOINT="checkpoints/task_models/SF_${TASK}"
|
| 5 |
-
REFERENCE_MODEL_CHECKPOINT="checkpoints/reference_models/openvla_oft_libero_${TASK}"
|
| 6 |
-
VECTOR_SAVE_PATH="checkpoints/feature_vectors/feature_vector_with_SF_${TASK}_v${VERSION}.pth"
|
| 7 |
-
INITIALIZED_PT_VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_${TASK}_v${VERSION}"
|
| 8 |
-
TASK_SUITE_NAME="libero_${TASK}"
|
| 9 |
-
|
| 10 |
-
python interpolate.py \
|
| 11 |
-
--pretrained_checkpoint "$TASK_MODEL_CHECKPOINT" \
|
| 12 |
-
--original_pretrained_checkpoint "$REFERENCE_MODEL_CHECKPOINT" \
|
| 13 |
-
--vector_save_path "$VECTOR_SAVE_PATH" \
|
| 14 |
-
--initialized_pt_vla_path $INITIALIZED_PT_VLA_PATH \
|
| 15 |
-
--pt_ckpt $PT_CKPT\
|
| 16 |
-
--feature_vector_weight 0.5 \
|
| 17 |
-
--task_suite_name $TASK_SUITE_NAME
|
| 18 |
-
|
| 19 |
-
#the code below is used to transplant the model parameters except for the vla backbone, such as processor and tokenizer, so that the initialized model is complete
|
| 20 |
-
|
| 21 |
-
rsync -av \
|
| 22 |
-
--ignore-existing \
|
| 23 |
-
--exclude='*.safetensors' \
|
| 24 |
-
--exclude='*.back.*' \
|
| 25 |
-
$PT_CKPT/ \
|
| 26 |
-
$INITIALIZED_PT_VLA_PATH/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/interpolate_robotwin.py
DELETED
|
@@ -1,247 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import json
|
| 8 |
-
import logging
|
| 9 |
-
|
| 10 |
-
import sys
|
| 11 |
-
from collections import deque
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from enum import Enum
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Optional, Union
|
| 16 |
-
from PIL import Image
|
| 17 |
-
|
| 18 |
-
import draccus
|
| 19 |
-
import numpy as np
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
import torch
|
| 22 |
-
import copy
|
| 23 |
-
|
| 24 |
-
import wandb
|
| 25 |
-
|
| 26 |
-
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 27 |
-
if str(REPO_ROOT) not in sys.path:
|
| 28 |
-
sys.path.append(str(REPO_ROOT))
|
| 29 |
-
from experiments.robot.openvla_utils import (
|
| 30 |
-
get_action_head,
|
| 31 |
-
get_noisy_action_projector,
|
| 32 |
-
get_processor,
|
| 33 |
-
get_proprio_projector,
|
| 34 |
-
resize_image_for_policy,
|
| 35 |
-
)
|
| 36 |
-
from experiments.robot.robot_utils import (
|
| 37 |
-
DATE_TIME,
|
| 38 |
-
get_action,
|
| 39 |
-
get_image_resize_size,
|
| 40 |
-
get_model,
|
| 41 |
-
invert_gripper_action,
|
| 42 |
-
normalize_gripper_action,
|
| 43 |
-
set_seed_everywhere,
|
| 44 |
-
)
|
| 45 |
-
from experiments.robot.libero.run_libero_eval import check_unnorm_key
|
| 46 |
-
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
|
| 47 |
-
from prismatic.vla.constants import PROPRIO_DIM
|
| 48 |
-
|
| 49 |
-
# Set up logging
|
| 50 |
-
logging.basicConfig(
|
| 51 |
-
level=logging.INFO,
|
| 52 |
-
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 53 |
-
handlers=[logging.StreamHandler()],
|
| 54 |
-
)
|
| 55 |
-
logger = logging.getLogger(__name__)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
@dataclass
|
| 59 |
-
class GenerateConfig:
|
| 60 |
-
# fmt: off
|
| 61 |
-
|
| 62 |
-
#################################################################################################################
|
| 63 |
-
# Model-specific parameters
|
| 64 |
-
#################################################################################################################
|
| 65 |
-
model_family: str = "openvla" # Model family
|
| 66 |
-
#the task-specific model after sf fine-tuning
|
| 67 |
-
pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model" # Task-specific checkpoint path
|
| 68 |
-
#the task-specific model after oft fine-tuning
|
| 69 |
-
original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model" # Reference checkpoint path
|
| 70 |
-
#feature vector is the difference between the two models, which represents the spatial features
|
| 71 |
-
vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth"
|
| 72 |
-
#the pt vla model initialized with the feature vector, named rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero}
|
| 73 |
-
initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla"
|
| 74 |
-
#the original pretrained openvla model
|
| 75 |
-
pt_ckpt: Union[str, Path] = "checkpoints/openvla_base"
|
| 76 |
-
#the weight of the feature vector when initializing the pt vla model
|
| 77 |
-
feature_vector_weight: float = 1 # Weight of feature vector for interpolation
|
| 78 |
-
|
| 79 |
-
use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective
|
| 80 |
-
use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM)
|
| 81 |
-
num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
|
| 82 |
-
num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference
|
| 83 |
-
use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
|
| 84 |
-
num_images_in_input: int = 3 # Number of images in the VLA input (default: 1)
|
| 85 |
-
use_proprio: bool = True # Whether to include proprio state in input
|
| 86 |
-
|
| 87 |
-
center_crop: bool = True # Center crop? (if trained w/ random crop image aug)
|
| 88 |
-
num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy
|
| 89 |
-
|
| 90 |
-
lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
|
| 91 |
-
|
| 92 |
-
unnorm_key: Union[str, Path] = "" # Action un-normalization key
|
| 93 |
-
|
| 94 |
-
load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization
|
| 95 |
-
load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization
|
| 96 |
-
|
| 97 |
-
#################################################################################################################
|
| 98 |
-
# LIBERO environment-specific parameters
|
| 99 |
-
#################################################################################################################
|
| 100 |
-
task_suite_name: str = "de" # Task suite
|
| 101 |
-
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
|
| 102 |
-
num_trials_per_task: int = 50 # Number of rollouts per task
|
| 103 |
-
initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file
|
| 104 |
-
env_img_res: int = 256 # Resolution for environment images (not policy input resolution)
|
| 105 |
-
|
| 106 |
-
#################################################################################################################
|
| 107 |
-
# Utils
|
| 108 |
-
#################################################################################################################
|
| 109 |
-
run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
|
| 110 |
-
local_log_dir: str = "./experiments/logs" # Local directory for eval logs
|
| 111 |
-
|
| 112 |
-
use_wandb: bool = False # Whether to also log results in Weights & Biases
|
| 113 |
-
wandb_entity: str = "your-wandb-entity" # Name of WandB entity
|
| 114 |
-
wandb_project: str = "your-wandb-project" # Name of WandB project
|
| 115 |
-
|
| 116 |
-
seed: int = 7 # Random Seed (for reproducibility)
|
| 117 |
-
|
| 118 |
-
def validate_config(cfg: GenerateConfig) -> None:
|
| 119 |
-
"""Validate configuration parameters."""
|
| 120 |
-
assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"
|
| 121 |
-
|
| 122 |
-
if "image_aug" in str(cfg.pretrained_checkpoint):
|
| 123 |
-
assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
|
| 124 |
-
|
| 125 |
-
assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"
|
| 126 |
-
|
| 127 |
-
# Validate task suite
|
| 128 |
-
# assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
|
| 129 |
-
|
| 130 |
-
def initialize_model(cfg: GenerateConfig, only_pt: bool = False): #load action_head and noisy_action_projector separately
|
| 131 |
-
"""Initialize model and associated components."""
|
| 132 |
-
# Load model
|
| 133 |
-
model = get_model(cfg)
|
| 134 |
-
|
| 135 |
-
# Load proprio projector if needed
|
| 136 |
-
proprio_projector = None
|
| 137 |
-
if cfg.use_proprio:
|
| 138 |
-
proprio_projector = get_proprio_projector(
|
| 139 |
-
cfg,
|
| 140 |
-
model.llm_dim,
|
| 141 |
-
proprio_dim=PROPRIO_DIM, #set the proprio_dim for different robots
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Load action head if needed
|
| 145 |
-
action_head = None
|
| 146 |
-
if cfg.use_l1_regression or cfg.use_diffusion:
|
| 147 |
-
action_head = get_action_head(cfg, model.llm_dim)
|
| 148 |
-
|
| 149 |
-
# Load noisy action projector if using diffusion
|
| 150 |
-
noisy_action_projector = None
|
| 151 |
-
if cfg.use_diffusion:
|
| 152 |
-
noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)
|
| 153 |
-
|
| 154 |
-
# Get OpenVLA processor if needed
|
| 155 |
-
processor = None
|
| 156 |
-
if not only_pt:
|
| 157 |
-
if cfg.model_family == "openvla":
|
| 158 |
-
processor = get_processor(cfg)
|
| 159 |
-
# check_unnorm_key(cfg, model)
|
| 160 |
-
|
| 161 |
-
return model, action_head, proprio_projector, noisy_action_projector, processor
|
| 162 |
-
|
| 163 |
-
# @draccus.wrap()
|
| 164 |
-
def generate_feature_vector(cfg: GenerateConfig):
|
| 165 |
-
"""Generate a feature vector (parameter differences) between two task-specific models."""
|
| 166 |
-
# Validate configuration
|
| 167 |
-
|
| 168 |
-
# Set random seed
|
| 169 |
-
set_seed_everywhere(cfg.seed)
|
| 170 |
-
|
| 171 |
-
# Initialize model and components
|
| 172 |
-
model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
|
| 173 |
-
|
| 174 |
-
original_config = GenerateConfig(
|
| 175 |
-
pretrained_checkpoint=cfg.original_pretrained_checkpoint,
|
| 176 |
-
task_suite_name=cfg.task_suite_name,
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
original_model, original_action_head, original_proprio_projector, original_noisy_action_projector, original_processor = initialize_model(original_config)
|
| 180 |
-
#for action_head and noisy_action_projector, these modules are not interpolated
|
| 181 |
-
assert len(model.state_dict()) == len(original_model.state_dict())
|
| 182 |
-
feature_vector_dict = {}
|
| 183 |
-
total = len(original_model.state_dict())
|
| 184 |
-
for name, original_model_param in tqdm(original_model.named_parameters(), total=total):
|
| 185 |
-
model_param = model.state_dict()[name]
|
| 186 |
-
feature_vector_dict[name] = (model_param - original_model_param).detach().cpu()
|
| 187 |
-
|
| 188 |
-
return feature_vector_dict
|
| 189 |
-
|
| 190 |
-
# @draccus.wrap()
|
| 191 |
-
def interpolate_feature_vector(cfg: GenerateConfig):
|
| 192 |
-
"""Interpolate feature vector."""
|
| 193 |
-
feature_vector_dict = torch.load(cfg.vector_save_path)
|
| 194 |
-
|
| 195 |
-
pt_vla_config = GenerateConfig(
|
| 196 |
-
pretrained_checkpoint=cfg.pt_ckpt,
|
| 197 |
-
original_pretrained_checkpoint=cfg.original_pretrained_checkpoint,
|
| 198 |
-
vector_save_path=cfg.vector_save_path,
|
| 199 |
-
initialized_pt_vla_path=cfg.initialized_pt_vla_path,
|
| 200 |
-
feature_vector_weight=cfg.feature_vector_weight,
|
| 201 |
-
pt_ckpt=cfg.pt_ckpt,
|
| 202 |
-
task_suite_name=cfg.task_suite_name,
|
| 203 |
-
use_proprio=False,
|
| 204 |
-
use_l1_regression=False,
|
| 205 |
-
use_diffusion=False
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
pt_vla,_,_,_,_ = initialize_model(pt_vla_config, only_pt=True)
|
| 209 |
-
|
| 210 |
-
#copy the SF parameters for checking the change before and after interpolation
|
| 211 |
-
model_sd = pt_vla.state_dict()
|
| 212 |
-
before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point}
|
| 213 |
-
|
| 214 |
-
with torch.no_grad():
|
| 215 |
-
pt_params = dict(pt_vla.named_parameters())
|
| 216 |
-
for name, diff in feature_vector_dict.items():
|
| 217 |
-
if name in pt_params:
|
| 218 |
-
pt_param = pt_params[name]
|
| 219 |
-
diff = diff.to(pt_param.device)
|
| 220 |
-
pt_param.add_(diff, alpha=cfg.feature_vector_weight)
|
| 221 |
-
|
| 222 |
-
#check after interpolation
|
| 223 |
-
diffs_after = []
|
| 224 |
-
for name, before_tensor in before_interp_sd.items():
|
| 225 |
-
after_tensor = model_sd[name]
|
| 226 |
-
difference = (after_tensor - before_tensor).float().norm().item()
|
| 227 |
-
diffs_after.append(difference)
|
| 228 |
-
|
| 229 |
-
print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, "
|
| 230 |
-
f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}")
|
| 231 |
-
|
| 232 |
-
#########################################################
|
| 233 |
-
return pt_vla
|
| 234 |
-
|
| 235 |
-
@draccus.wrap()
|
| 236 |
-
def main(cfg: GenerateConfig):
|
| 237 |
-
if not os.path.exists(cfg.vector_save_path):
|
| 238 |
-
feature_vector_dict = generate_feature_vector(cfg)
|
| 239 |
-
torch.save(feature_vector_dict, cfg.vector_save_path)
|
| 240 |
-
else:
|
| 241 |
-
print(f"Feature vector already exists at {cfg.vector_save_path}")
|
| 242 |
-
initialized_pt_vla = interpolate_feature_vector(cfg)
|
| 243 |
-
os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True)
|
| 244 |
-
initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path)
|
| 245 |
-
|
| 246 |
-
if __name__ == "__main__":
|
| 247 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/tools/check_model_config.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
#This is for checking the completeness of the model parameters.
|
| 2 |
-
import argparse
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def main():
|
| 8 |
-
parser = argparse.ArgumentParser()
|
| 9 |
-
parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)")
|
| 10 |
-
args = parser.parse_args()
|
| 11 |
-
|
| 12 |
-
fv = torch.load(args.checkpoint_path, map_location="cpu")
|
| 13 |
-
|
| 14 |
-
print("num_tensors:", len(fv))
|
| 15 |
-
nz = 0
|
| 16 |
-
for _, value in fv.items():
|
| 17 |
-
if value.abs().sum().item() != 0:
|
| 18 |
-
nz += 1
|
| 19 |
-
print("nonzero_tensors:", nz)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if __name__ == "__main__":
|
| 23 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/tools/compute_lora_diff.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
#This is for computing the difference between the base model and the target model.
|
| 2 |
-
from safetensors.torch import load_file, save_file
|
| 3 |
-
import torch
|
| 4 |
-
import argparse
|
| 5 |
-
|
| 6 |
-
def main():
|
| 7 |
-
parser = argparse.ArgumentParser()
|
| 8 |
-
parser.add_argument("--base", required=True)
|
| 9 |
-
parser.add_argument("--target", required=True)
|
| 10 |
-
parser.add_argument("--out", default="lora_diff.safetensors")
|
| 11 |
-
args = parser.parse_args()
|
| 12 |
-
|
| 13 |
-
base = load_file(args.base)
|
| 14 |
-
target = load_file(args.target)
|
| 15 |
-
|
| 16 |
-
diff = {}
|
| 17 |
-
|
| 18 |
-
print("=== Key Comparison ===")
|
| 19 |
-
only_in_base = set(base) - set(target)
|
| 20 |
-
only_in_target = set(target) - set(base)
|
| 21 |
-
|
| 22 |
-
print("Only in base:", list(only_in_base)[:10])
|
| 23 |
-
print("Only in target:", list(only_in_target)[:10])
|
| 24 |
-
|
| 25 |
-
for k in target:
|
| 26 |
-
if k in base:
|
| 27 |
-
diff[k] = target[k] - base[k]
|
| 28 |
-
else:
|
| 29 |
-
# keep the new parameters
|
| 30 |
-
diff[k] = target[k].clone()
|
| 31 |
-
|
| 32 |
-
save_file(diff, args.out)
|
| 33 |
-
print(f"\nSaved diff to: {args.out}")
|
| 34 |
-
|
| 35 |
-
if __name__ == "__main__":
|
| 36 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/tools/compute_lora_diff.sh
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
BASE_ADAPTER="checkpoints/reference_models/openvla_oft_libero_spatial/lora_adapter/adapter_model.safetensors"
|
| 2 |
-
TARGET_ADAPTER="checkpoints/task_models/SF_spatial/lora_adapter/adapter_model.safetensors"
|
| 3 |
-
OUTPUT_DIFF="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors"
|
| 4 |
-
|
| 5 |
-
python compute_lora_diff.py \
|
| 6 |
-
--base "$BASE_ADAPTER" \
|
| 7 |
-
--target "$TARGET_ADAPTER" \
|
| 8 |
-
--out "$OUTPUT_DIFF"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/tools/vector_analyze.py
DELETED
|
@@ -1,153 +0,0 @@
|
|
| 1 |
-
#This is for analyzing the vector of the model and finding out which layers have the largest absolute values.
|
| 2 |
-
import argparse
|
| 3 |
-
import csv
|
| 4 |
-
import os
|
| 5 |
-
import re
|
| 6 |
-
from collections import OrderedDict, defaultdict
|
| 7 |
-
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
LAYER_PREFIX = "language_model.model.layers."
|
| 13 |
-
NUM_LAYERS = 32
|
| 14 |
-
USE_LOG_Y = True
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def pick_state_dict(obj):
|
| 18 |
-
if isinstance(obj, (OrderedDict, dict)):
|
| 19 |
-
for key in ["state_dict", "model_state_dict", "model", "net", "weights", "params"]:
|
| 20 |
-
if key in obj and isinstance(obj[key], (OrderedDict, dict)):
|
| 21 |
-
return obj[key]
|
| 22 |
-
if any(torch.is_tensor(value) for value in obj.values()):
|
| 23 |
-
return obj
|
| 24 |
-
return None
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def aggregate_layers_abs_sum(state_dict):
|
| 28 |
-
layer_sum = defaultdict(float)
|
| 29 |
-
layer_cnt = defaultdict(int)
|
| 30 |
-
pattern = re.compile(r"^" + re.escape(LAYER_PREFIX) + r"(\d+)\.")
|
| 31 |
-
|
| 32 |
-
for name, tensor in state_dict.items():
|
| 33 |
-
if not isinstance(name, str):
|
| 34 |
-
continue
|
| 35 |
-
match = pattern.match(name)
|
| 36 |
-
if match is None or not torch.is_tensor(tensor):
|
| 37 |
-
continue
|
| 38 |
-
|
| 39 |
-
layer_id = int(match.group(1))
|
| 40 |
-
if layer_id < 0 or layer_id >= NUM_LAYERS:
|
| 41 |
-
continue
|
| 42 |
-
|
| 43 |
-
value = tensor.detach()
|
| 44 |
-
if value.is_cuda:
|
| 45 |
-
value = value.cpu()
|
| 46 |
-
|
| 47 |
-
value = value.to(torch.float64)
|
| 48 |
-
layer_sum[layer_id] += value.abs().sum().item()
|
| 49 |
-
layer_cnt[layer_id] += 1
|
| 50 |
-
|
| 51 |
-
for layer_id in range(NUM_LAYERS):
|
| 52 |
-
layer_sum[layer_id] = float(layer_sum.get(layer_id, 0.0))
|
| 53 |
-
layer_cnt[layer_id] = int(layer_cnt.get(layer_id, 0))
|
| 54 |
-
|
| 55 |
-
return layer_sum, layer_cnt
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def save_layer_csv(layer_sum, layer_cnt, path):
|
| 59 |
-
output_dir = os.path.dirname(path)
|
| 60 |
-
if output_dir:
|
| 61 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 62 |
-
with open(path, "w", newline="") as file_obj:
|
| 63 |
-
writer = csv.DictWriter(file_obj, fieldnames=["layer_id", "abs_sum", "num_tensors"])
|
| 64 |
-
writer.writeheader()
|
| 65 |
-
for layer_id in range(NUM_LAYERS):
|
| 66 |
-
writer.writerow(
|
| 67 |
-
{
|
| 68 |
-
"layer_id": layer_id,
|
| 69 |
-
"abs_sum": f"{layer_sum[layer_id]:.12e}",
|
| 70 |
-
"num_tensors": layer_cnt[layer_id],
|
| 71 |
-
}
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def plot_line(xs, ys, out_png, title):
|
| 76 |
-
ys_plot = ys[:]
|
| 77 |
-
if USE_LOG_Y:
|
| 78 |
-
min_pos = min([value for value in ys_plot if value > 0], default=1e-300)
|
| 79 |
-
eps = min_pos * 1e-6 if min_pos > 0 else 1e-300
|
| 80 |
-
ys_plot = [value if value > 0 else eps for value in ys_plot]
|
| 81 |
-
|
| 82 |
-
plt.figure(figsize=(12, 4.5))
|
| 83 |
-
plt.plot(xs, ys_plot, marker="o", linewidth=1.5)
|
| 84 |
-
plt.xlabel("Layer id")
|
| 85 |
-
plt.ylabel("abs_sum (all params in layer)")
|
| 86 |
-
plt.title(title)
|
| 87 |
-
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5)
|
| 88 |
-
if USE_LOG_Y:
|
| 89 |
-
plt.yscale("log")
|
| 90 |
-
plt.tight_layout()
|
| 91 |
-
plt.savefig(out_png, dpi=200)
|
| 92 |
-
plt.close()
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def plot_bar(xs, ys, out_png, title):
|
| 96 |
-
ys_plot = ys[:]
|
| 97 |
-
if USE_LOG_Y:
|
| 98 |
-
min_pos = min([value for value in ys_plot if value > 0], default=1e-300)
|
| 99 |
-
eps = min_pos * 1e-6 if min_pos > 0 else 1e-300
|
| 100 |
-
ys_plot = [value if value > 0 else eps for value in ys_plot]
|
| 101 |
-
|
| 102 |
-
plt.figure(figsize=(12, 4.5))
|
| 103 |
-
plt.bar(xs, ys_plot)
|
| 104 |
-
plt.xlabel("Layer id")
|
| 105 |
-
plt.ylabel("abs_sum (all params in layer)")
|
| 106 |
-
plt.title(title)
|
| 107 |
-
plt.grid(True, which="both", axis="y", linestyle="--", linewidth=0.5, alpha=0.5)
|
| 108 |
-
if USE_LOG_Y:
|
| 109 |
-
plt.yscale("log")
|
| 110 |
-
plt.tight_layout()
|
| 111 |
-
plt.savefig(out_png, dpi=200)
|
| 112 |
-
plt.close()
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def main():
|
| 116 |
-
parser = argparse.ArgumentParser()
|
| 117 |
-
parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)")
|
| 118 |
-
args = parser.parse_args()
|
| 119 |
-
|
| 120 |
-
base = os.path.splitext(args.checkpoint_path)[0]
|
| 121 |
-
out_csv = base + "_language_model_layers_abs_sum.csv"
|
| 122 |
-
out_png_line = base + "_language_model_layers_abs_sum_line.png"
|
| 123 |
-
out_png_bar = base + "_language_model_layers_abs_sum_bar.png"
|
| 124 |
-
|
| 125 |
-
ckpt = torch.load(args.checkpoint_path, map_location="cpu")
|
| 126 |
-
state_dict = pick_state_dict(ckpt)
|
| 127 |
-
|
| 128 |
-
if state_dict is None:
|
| 129 |
-
print("Not a state_dict-like dict. Type:", type(ckpt))
|
| 130 |
-
if isinstance(ckpt, dict):
|
| 131 |
-
print("Top-level keys:", list(ckpt.keys())[:50])
|
| 132 |
-
raise SystemExit(1)
|
| 133 |
-
|
| 134 |
-
layer_sum, layer_cnt = aggregate_layers_abs_sum(state_dict)
|
| 135 |
-
save_layer_csv(layer_sum, layer_cnt, out_csv)
|
| 136 |
-
print(f"Saved CSV: {out_csv}")
|
| 137 |
-
|
| 138 |
-
xs = list(range(NUM_LAYERS))
|
| 139 |
-
ys = [layer_sum[i] for i in xs]
|
| 140 |
-
plot_line(xs, ys, out_png_line, f"{LAYER_PREFIX}*: abs_sum per layer")
|
| 141 |
-
plot_bar(xs, ys, out_png_bar, f"{LAYER_PREFIX}*: abs_sum per layer")
|
| 142 |
-
|
| 143 |
-
print(f"Saved plot: {out_png_line}")
|
| 144 |
-
print(f"Saved plot: {out_png_bar}")
|
| 145 |
-
|
| 146 |
-
top = sorted(((i, layer_sum[i], layer_cnt[i]) for i in xs), key=lambda item: item[1], reverse=True)[:5]
|
| 147 |
-
print("Top-5 layers by abs_sum:")
|
| 148 |
-
for layer_id, abs_sum, tensor_count in top:
|
| 149 |
-
print(f" layer {layer_id:02d}: abs_sum={abs_sum:.6e}, tensors={tensor_count}")
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
if __name__ == "__main__":
|
| 153 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/capvector/tools/vector_regularize.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
# Used to regularize feature vectors by first computing the absolute-sum of each layer and then performing normalization
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
from collections import OrderedDict
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def pick_state_dict(obj):
|
| 10 |
-
"""Extract state_dict from a checkpoint-like object"""
|
| 11 |
-
if isinstance(obj, (OrderedDict, dict)):
|
| 12 |
-
for k in ["state_dict", "model_state_dict", "model", "net", "weights", "params"]:
|
| 13 |
-
if k in obj and isinstance(obj[k], (OrderedDict, dict)):
|
| 14 |
-
return obj[k]
|
| 15 |
-
if any(torch.is_tensor(v) for v in obj.values()):
|
| 16 |
-
return obj
|
| 17 |
-
return None
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def calculate_total_abs_sum(state_dict):
|
| 21 |
-
"""Compute the sum of absolute values over all parameters"""
|
| 22 |
-
total_sum = 0.0
|
| 23 |
-
param_count = 0
|
| 24 |
-
|
| 25 |
-
for name, tensor in state_dict.items():
|
| 26 |
-
if not torch.is_tensor(tensor):
|
| 27 |
-
continue
|
| 28 |
-
|
| 29 |
-
x = tensor.detach()
|
| 30 |
-
if x.is_cuda:
|
| 31 |
-
x = x.cpu()
|
| 32 |
-
|
| 33 |
-
# Use float64 to ensure numerical precision
|
| 34 |
-
x = x.to(torch.float64)
|
| 35 |
-
abs_sum = x.abs().sum().item()
|
| 36 |
-
total_sum += abs_sum
|
| 37 |
-
param_count += 1
|
| 38 |
-
|
| 39 |
-
print(f"{name}: {abs_sum:.12e} (shape: {list(x.shape)}, numel: {x.numel()})")
|
| 40 |
-
|
| 41 |
-
return total_sum, param_count
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def main():
|
| 45 |
-
parser = argparse.ArgumentParser()
|
| 46 |
-
parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)")
|
| 47 |
-
args = parser.parse_args()
|
| 48 |
-
|
| 49 |
-
print(f"Loading checkpoint: {args.checkpoint_path}")
|
| 50 |
-
ckpt = torch.load(args.checkpoint_path, map_location="cpu")
|
| 51 |
-
sd = pick_state_dict(ckpt)
|
| 52 |
-
|
| 53 |
-
if sd is None:
|
| 54 |
-
print("Error: failed to extract state_dict from checkpoint")
|
| 55 |
-
print(f"Checkpoint type: {type(ckpt)}")
|
| 56 |
-
if isinstance(ckpt, dict):
|
| 57 |
-
print(f"Top-level keys: {list(ckpt.keys())[:20]}")
|
| 58 |
-
raise SystemExit(1)
|
| 59 |
-
|
| 60 |
-
print(f"\nFound {len(sd)} parameters\n")
|
| 61 |
-
print("=" * 80)
|
| 62 |
-
print("Absolute-sum of each parameter:")
|
| 63 |
-
print("=" * 80)
|
| 64 |
-
|
| 65 |
-
total_abs_sum, param_count = calculate_total_abs_sum(sd)
|
| 66 |
-
|
| 67 |
-
print("=" * 80)
|
| 68 |
-
print(f"\nSummary:")
|
| 69 |
-
print(f" Total number of parameters: {param_count}")
|
| 70 |
-
print(f" Sum of absolute values of all parameters: {total_abs_sum:.12e}")
|
| 71 |
-
print(f" Sum of absolute values of all parameters (scientific notation): {total_abs_sum:.6e}")
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
if __name__ == "__main__":
|
| 75 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/aloha_utils.py
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
"""Utils for evaluating policies in real-world ALOHA environments."""
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
import imageio
|
| 6 |
-
import numpy as np
|
| 7 |
-
from PIL import Image
|
| 8 |
-
|
| 9 |
-
from experiments.robot.aloha.real_env import make_real_env
|
| 10 |
-
from experiments.robot.robot_utils import (
|
| 11 |
-
DATE,
|
| 12 |
-
DATE_TIME,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def get_next_task_label(task_label):
|
| 17 |
-
"""Prompt the user to input the next task."""
|
| 18 |
-
if task_label == "":
|
| 19 |
-
user_input = ""
|
| 20 |
-
while user_input == "":
|
| 21 |
-
user_input = input("Enter the task name: ")
|
| 22 |
-
task_label = user_input
|
| 23 |
-
else:
|
| 24 |
-
user_input = input("Enter the task name (or leave blank to repeat the previous task): ")
|
| 25 |
-
if user_input == "":
|
| 26 |
-
pass # Do nothing -> Let task_label be the same
|
| 27 |
-
else:
|
| 28 |
-
task_label = user_input
|
| 29 |
-
print(f"Task: {task_label}")
|
| 30 |
-
return task_label
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def get_aloha_env():
|
| 34 |
-
"""Initializes and returns the ALOHA environment."""
|
| 35 |
-
env = make_real_env(init_node=True)
|
| 36 |
-
return env
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def resize_image_for_preprocessing(img):
|
| 40 |
-
"""
|
| 41 |
-
Takes numpy array corresponding to a single image and resizes to 256x256, exactly as done
|
| 42 |
-
in the ALOHA data preprocessing script, which is used before converting the dataset to RLDS.
|
| 43 |
-
"""
|
| 44 |
-
ALOHA_PREPROCESS_SIZE = 256
|
| 45 |
-
img = np.array(
|
| 46 |
-
Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC)
|
| 47 |
-
) # BICUBIC is default; specify explicitly to make it clear
|
| 48 |
-
return img
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def get_aloha_image(obs):
|
| 52 |
-
"""Extracts third-person image from observations and preprocesses it."""
|
| 53 |
-
# obs: dm_env._environment.TimeStep
|
| 54 |
-
img = obs.observation["images"]["cam_high"]
|
| 55 |
-
img = resize_image_for_preprocessing(img)
|
| 56 |
-
return img
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def get_aloha_wrist_images(obs):
|
| 60 |
-
"""Extracts both wrist camera images from observations and preprocesses them."""
|
| 61 |
-
# obs: dm_env._environment.TimeStep
|
| 62 |
-
left_wrist_img = obs.observation["images"]["cam_left_wrist"]
|
| 63 |
-
right_wrist_img = obs.observation["images"]["cam_right_wrist"]
|
| 64 |
-
left_wrist_img = resize_image_for_preprocessing(left_wrist_img)
|
| 65 |
-
right_wrist_img = resize_image_for_preprocessing(right_wrist_img)
|
| 66 |
-
return left_wrist_img, right_wrist_img
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None):
|
| 70 |
-
"""Saves an MP4 replay of an episode."""
|
| 71 |
-
rollout_dir = f"./rollouts/{DATE}"
|
| 72 |
-
os.makedirs(rollout_dir, exist_ok=True)
|
| 73 |
-
processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
|
| 74 |
-
filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}"
|
| 75 |
-
if notes is not None:
|
| 76 |
-
filetag += f"--{notes}"
|
| 77 |
-
mp4_path = f"{filetag}.mp4"
|
| 78 |
-
video_writer = imageio.get_writer(mp4_path, fps=25)
|
| 79 |
-
for img in rollout_images:
|
| 80 |
-
video_writer.append_data(img)
|
| 81 |
-
video_writer.close()
|
| 82 |
-
print(f"Saved rollout MP4 at path {mp4_path}")
|
| 83 |
-
if log_file is not None:
|
| 84 |
-
log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
|
| 85 |
-
return mp4_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/constants.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
### Task parameters
|
| 2 |
-
|
| 3 |
-
DATA_DIR = '/scr2/moojink/data/aloha1/'
|
| 4 |
-
TASK_CONFIGS = {
|
| 5 |
-
# fold shorts
|
| 6 |
-
'fold_shorts':{
|
| 7 |
-
'dataset_dir': DATA_DIR + '/fold_shorts',
|
| 8 |
-
'num_episodes': 20,
|
| 9 |
-
'episode_len': 1000,
|
| 10 |
-
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
|
| 11 |
-
},
|
| 12 |
-
# fold shirt
|
| 13 |
-
'fold_shirt':{
|
| 14 |
-
'dataset_dir': DATA_DIR + '/fold_shirt',
|
| 15 |
-
'num_episodes': 30,
|
| 16 |
-
'episode_len': 1250,
|
| 17 |
-
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
|
| 18 |
-
},
|
| 19 |
-
# scoop X into bowl
|
| 20 |
-
'scoop_raisins_into_bowl':{
|
| 21 |
-
'dataset_dir': DATA_DIR + '/scoop_raisins_into_bowl',
|
| 22 |
-
'num_episodes': 15,
|
| 23 |
-
'episode_len': 900,
|
| 24 |
-
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
|
| 25 |
-
},
|
| 26 |
-
'scoop_almonds_and_green_M&Ms_into_bowl':{
|
| 27 |
-
'dataset_dir': DATA_DIR + '/scoop_almonds_and_green_M&Ms_into_bowl',
|
| 28 |
-
'num_episodes': 15,
|
| 29 |
-
'episode_len': 900,
|
| 30 |
-
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
|
| 31 |
-
},
|
| 32 |
-
'scoop_pretzels_into_bowl':{
|
| 33 |
-
'dataset_dir': DATA_DIR + '/scoop_pretzels_into_bowl',
|
| 34 |
-
'num_episodes': 15,
|
| 35 |
-
'episode_len': 900,
|
| 36 |
-
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
|
| 37 |
-
},
|
| 38 |
-
# put X into pot
|
| 39 |
-
'put_red_pepper_into_pot':{
|
| 40 |
-
'dataset_dir': DATA_DIR + '/put_red_pepper_into_pot',
|
| 41 |
-
'num_episodes': 100,
|
| 42 |
-
'episode_len': 400,
|
| 43 |
-
'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
|
| 44 |
-
},
|
| 45 |
-
'put_yellow_corn_into_pot':{
|
| 46 |
-
'dataset_dir': DATA_DIR + '/put_yellow_corn_into_pot',
|
| 47 |
-
'num_episodes': 100,
|
| 48 |
-
'episode_len': 400,
|
| 49 |
-
'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
|
| 50 |
-
},
|
| 51 |
-
'put_green_pepper_into_pot':{
|
| 52 |
-
'dataset_dir': DATA_DIR + '/put_green_pepper_into_pot',
|
| 53 |
-
'num_episodes': 100,
|
| 54 |
-
'episode_len': 400,
|
| 55 |
-
'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
|
| 56 |
-
},
|
| 57 |
-
}
|
| 58 |
-
|
| 59 |
-
### ALOHA fixed constants
|
| 60 |
-
DT = 0.04 # 1 / 0.04 -> 25 Hz
|
| 61 |
-
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
| 62 |
-
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
| 63 |
-
|
| 64 |
-
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
| 65 |
-
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
| 66 |
-
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
| 67 |
-
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
| 68 |
-
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
| 69 |
-
|
| 70 |
-
# Gripper joint limits (qpos[6])
|
| 71 |
-
MASTER_GRIPPER_JOINT_OPEN = 0.3083 # For ALOHA 1
|
| 72 |
-
MASTER_GRIPPER_JOINT_CLOSE = -0.6842 # For ALOHA 1
|
| 73 |
-
# MASTER_GRIPPER_JOINT_OPEN = -0.8 # For ALOHA 2
|
| 74 |
-
# MASTER_GRIPPER_JOINT_CLOSE = -1.65 # For ALOHA 2
|
| 75 |
-
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
| 76 |
-
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
| 77 |
-
|
| 78 |
-
############################ Helper functions ############################
|
| 79 |
-
|
| 80 |
-
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
| 81 |
-
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
| 82 |
-
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
| 83 |
-
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
| 84 |
-
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
| 85 |
-
|
| 86 |
-
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
| 87 |
-
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
| 88 |
-
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
| 89 |
-
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
| 90 |
-
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
| 91 |
-
|
| 92 |
-
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
| 93 |
-
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
| 94 |
-
|
| 95 |
-
MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
| 96 |
-
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
|
| 97 |
-
PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
| 98 |
-
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
|
| 99 |
-
|
| 100 |
-
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py
DELETED
|
@@ -1,260 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Preprocesses ALOHA dataset(s) and splits them into train/val sets.
|
| 3 |
-
|
| 4 |
-
Preprocessing includes downsizing images from 480x640 to 256x256.
|
| 5 |
-
Splits happen at the episode level (not step level), which means that
|
| 6 |
-
an episode is treated as an atomic unit that entirely goes to either
|
| 7 |
-
the train set or val set.
|
| 8 |
-
|
| 9 |
-
Original ALOHA data layout:
|
| 10 |
-
/PATH/TO/DATASET/dataset_name/
|
| 11 |
-
- episode_0.hdf5
|
| 12 |
-
- episode_1.hdf5
|
| 13 |
-
- ...
|
| 14 |
-
- episode_N.hdf5
|
| 15 |
-
|
| 16 |
-
Preprocessed data layout (after running this script):
|
| 17 |
-
/PATH/TO/PREPROCESSED_DATASETS/dataset_name/
|
| 18 |
-
- train/
|
| 19 |
-
- episode_0.hdf5
|
| 20 |
-
- episode_1.hdf5
|
| 21 |
-
- ...
|
| 22 |
-
- episode_M.hdf5
|
| 23 |
-
- val/
|
| 24 |
-
- episode_0.hdf5
|
| 25 |
-
- episode_1.hdf5
|
| 26 |
-
- ...
|
| 27 |
-
- episode_K.hdf5
|
| 28 |
-
|
| 29 |
-
where N > M > K
|
| 30 |
-
|
| 31 |
-
Example usage:
|
| 32 |
-
# "put X into pot" task
|
| 33 |
-
python experiments/robot/aloha/preprocess_split_aloha_data.py \
|
| 34 |
-
--dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \
|
| 35 |
-
--out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
|
| 36 |
-
--percent_val 0.05 && \
|
| 37 |
-
python experiments/robot/aloha/preprocess_split_aloha_data.py \
|
| 38 |
-
--dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \
|
| 39 |
-
--out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
|
| 40 |
-
--percent_val 0.05 && \
|
| 41 |
-
python experiments/robot/aloha/preprocess_split_aloha_data.py \
|
| 42 |
-
--dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \
|
| 43 |
-
--out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
|
| 44 |
-
--percent_val 0.05
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
import argparse
|
| 48 |
-
import glob
|
| 49 |
-
import os
|
| 50 |
-
import random
|
| 51 |
-
|
| 52 |
-
import h5py
|
| 53 |
-
import numpy as np
|
| 54 |
-
from PIL import Image
|
| 55 |
-
from tqdm import tqdm
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def load_hdf5(demo_path):
|
| 59 |
-
"""Loads single episode."""
|
| 60 |
-
if not os.path.isfile(demo_path):
|
| 61 |
-
print(f"Dataset does not exist at \n{demo_path}\n")
|
| 62 |
-
exit()
|
| 63 |
-
|
| 64 |
-
print(f"Loading {demo_path}...")
|
| 65 |
-
with h5py.File(demo_path, "r") as root:
|
| 66 |
-
is_sim = root.attrs["sim"]
|
| 67 |
-
qpos = root["/observations/qpos"][()]
|
| 68 |
-
qvel = root["/observations/qvel"][()]
|
| 69 |
-
effort = root["/observations/effort"][()]
|
| 70 |
-
action = root["/action"][()]
|
| 71 |
-
image_dict = dict()
|
| 72 |
-
for cam_name in root["/observations/images/"].keys():
|
| 73 |
-
image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()]
|
| 74 |
-
print(f"Loading episode complete: {demo_path}")
|
| 75 |
-
|
| 76 |
-
return qpos, qvel, effort, action, image_dict, is_sim
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def load_and_preprocess_all_episodes(demo_paths, out_dataset_dir):
|
| 80 |
-
"""
|
| 81 |
-
Loads and preprocesses all episodes.
|
| 82 |
-
Resizes all images in one episode before loading the next, to reduce memory usage.
|
| 83 |
-
"""
|
| 84 |
-
cam_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
|
| 85 |
-
idx = 0
|
| 86 |
-
for demo in tqdm(demo_paths):
|
| 87 |
-
qpos, qvel, effort, action, image_dict, is_sim = load_hdf5(demo)
|
| 88 |
-
# Save non-image info
|
| 89 |
-
episode_len = image_dict["cam_high"].shape[0]
|
| 90 |
-
# Resize all images
|
| 91 |
-
print("Resizing images in episode...")
|
| 92 |
-
for k in cam_names:
|
| 93 |
-
resized_images = []
|
| 94 |
-
for i in range(episode_len):
|
| 95 |
-
resized_images.append(
|
| 96 |
-
np.array(
|
| 97 |
-
Image.fromarray(image_dict[k][i]).resize(
|
| 98 |
-
(args.img_resize_size, args.img_resize_size), resample=Image.BICUBIC
|
| 99 |
-
)
|
| 100 |
-
) # BICUBIC is default; specify explicitly to make it clear
|
| 101 |
-
)
|
| 102 |
-
image_dict[k] = np.stack(resized_images)
|
| 103 |
-
print("Resizing images in episode complete!")
|
| 104 |
-
# Save preprocessed episode
|
| 105 |
-
data_dict = dict(
|
| 106 |
-
qpos=qpos,
|
| 107 |
-
qvel=qvel,
|
| 108 |
-
effort=effort,
|
| 109 |
-
action=action,
|
| 110 |
-
image_dict=image_dict,
|
| 111 |
-
is_sim=is_sim,
|
| 112 |
-
)
|
| 113 |
-
save_new_hdf5(out_dataset_dir, data_dict, idx)
|
| 114 |
-
idx += 1
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def randomly_split(full_qpos, full_qvel, full_effort, full_action, full_image_dict, percent_val):
|
| 118 |
-
"""Randomly splits dataset into train and validation sets."""
|
| 119 |
-
# Create a list of episode indices
|
| 120 |
-
num_episodes_total = len(full_qpos)
|
| 121 |
-
indices = list(range(num_episodes_total))
|
| 122 |
-
# Shuffle the episode indices
|
| 123 |
-
random.shuffle(indices)
|
| 124 |
-
# Create new lists using the shuffled indices
|
| 125 |
-
shuffled_qpos = [full_qpos[idx] for idx in indices]
|
| 126 |
-
shuffled_qvel = [full_qvel[idx] for idx in indices]
|
| 127 |
-
shuffled_effort = [full_effort[idx] for idx in indices]
|
| 128 |
-
shuffled_action = [full_action[idx] for idx in indices]
|
| 129 |
-
shuffled_image_dict = {
|
| 130 |
-
"cam_high": [],
|
| 131 |
-
"cam_left_wrist": [],
|
| 132 |
-
"cam_right_wrist": [],
|
| 133 |
-
}
|
| 134 |
-
for k in full_image_dict.keys():
|
| 135 |
-
shuffled_image_dict[k] = [full_image_dict[k][idx] for idx in indices]
|
| 136 |
-
# Split into train and val sets
|
| 137 |
-
num_episodes_val = int(num_episodes_total * percent_val)
|
| 138 |
-
print(f"Total # steps: {num_episodes_total}; using {num_episodes_val} ({percent_val:.2f}%) for val set")
|
| 139 |
-
num_episodes_train = num_episodes_total - num_episodes_val
|
| 140 |
-
train_dict = dict(
|
| 141 |
-
qpos=shuffled_qpos[:num_episodes_train],
|
| 142 |
-
qvel=shuffled_qvel[:num_episodes_train],
|
| 143 |
-
effort=shuffled_effort[:num_episodes_train],
|
| 144 |
-
action=shuffled_action[:num_episodes_train],
|
| 145 |
-
image_dict=dict(
|
| 146 |
-
cam_high=shuffled_image_dict["cam_high"][:num_episodes_train],
|
| 147 |
-
cam_left_wrist=shuffled_image_dict["cam_left_wrist"][:num_episodes_train],
|
| 148 |
-
cam_right_wrist=shuffled_image_dict["cam_right_wrist"][:num_episodes_train],
|
| 149 |
-
),
|
| 150 |
-
)
|
| 151 |
-
val_dict = dict(
|
| 152 |
-
qpos=shuffled_qpos[num_episodes_train:],
|
| 153 |
-
qvel=shuffled_qvel[num_episodes_train:],
|
| 154 |
-
effort=shuffled_effort[num_episodes_train:],
|
| 155 |
-
action=shuffled_action[num_episodes_train:],
|
| 156 |
-
image_dict=dict(
|
| 157 |
-
cam_high=shuffled_image_dict["cam_high"][num_episodes_train:],
|
| 158 |
-
cam_left_wrist=shuffled_image_dict["cam_left_wrist"][num_episodes_train:],
|
| 159 |
-
cam_right_wrist=shuffled_image_dict["cam_right_wrist"][num_episodes_train:],
|
| 160 |
-
),
|
| 161 |
-
)
|
| 162 |
-
return train_dict, val_dict
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
def save_new_hdf5(out_dataset_dir, data_dict, episode_idx):
|
| 166 |
-
"""Saves an HDF5 file for a new episode."""
|
| 167 |
-
camera_names = data_dict["image_dict"].keys()
|
| 168 |
-
H, W, C = data_dict["image_dict"]["cam_high"][0].shape
|
| 169 |
-
out_path = os.path.join(out_dataset_dir, f"episode_{episode_idx}.hdf5")
|
| 170 |
-
# Save HDF5 with same structure as original demos (except that now we combine all episodes into one HDF5 file)
|
| 171 |
-
with h5py.File(
|
| 172 |
-
out_path, "w", rdcc_nbytes=1024**2 * 2
|
| 173 |
-
) as root: # Magic constant for rdcc_nbytes comes from ALOHA codebase
|
| 174 |
-
episode_len = data_dict["qpos"].shape[0]
|
| 175 |
-
root.attrs["sim"] = data_dict["is_sim"]
|
| 176 |
-
obs = root.create_group("observations")
|
| 177 |
-
_ = obs.create_dataset("qpos", (episode_len, 14))
|
| 178 |
-
_ = obs.create_dataset("qvel", (episode_len, 14))
|
| 179 |
-
_ = obs.create_dataset("effort", (episode_len, 14))
|
| 180 |
-
root["/observations/qpos"][...] = data_dict["qpos"]
|
| 181 |
-
root["/observations/qvel"][...] = data_dict["qvel"]
|
| 182 |
-
root["/observations/effort"][...] = data_dict["effort"]
|
| 183 |
-
image = obs.create_group("images")
|
| 184 |
-
for cam_name in camera_names:
|
| 185 |
-
_ = image.create_dataset(
|
| 186 |
-
cam_name,
|
| 187 |
-
(episode_len, H, W, C),
|
| 188 |
-
dtype="uint8",
|
| 189 |
-
chunks=(1, H, W, C),
|
| 190 |
-
)
|
| 191 |
-
root[f"/observations/images/{cam_name}"][...] = data_dict["image_dict"][cam_name]
|
| 192 |
-
_ = root.create_dataset("action", (episode_len, 14))
|
| 193 |
-
root["/action"][...] = data_dict["action"]
|
| 194 |
-
# Compute and save *relative* actions as well
|
| 195 |
-
actions = data_dict["action"]
|
| 196 |
-
relative_actions = np.zeros_like(actions)
|
| 197 |
-
relative_actions[:-1] = actions[1:] - actions[:-1] # Relative actions are the changes in joint pos
|
| 198 |
-
relative_actions[-1] = relative_actions[-2] # Just copy the second-to-last action for the last action
|
| 199 |
-
_ = root.create_dataset("relative_action", (episode_len, 14))
|
| 200 |
-
root["/relative_action"][...] = relative_actions
|
| 201 |
-
print(f"Saved dataset: {out_path}")
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def main(args):
|
| 205 |
-
# Create directory to save preprocessed dataset (if it doesn't exist already)
|
| 206 |
-
os.makedirs(args.out_base_dir, exist_ok=True)
|
| 207 |
-
out_dataset_dir = os.path.join(args.out_base_dir, os.path.basename(args.dataset_path.rstrip("/")))
|
| 208 |
-
os.makedirs(out_dataset_dir, exist_ok=True)
|
| 209 |
-
# Get list of filepaths of all episodes
|
| 210 |
-
all_demo_paths = glob.glob(os.path.join(args.dataset_path, "*.hdf5")) # List of HDF5 filepaths
|
| 211 |
-
all_demo_paths.sort()
|
| 212 |
-
# Create a list of episode indices
|
| 213 |
-
num_episodes_total = len(all_demo_paths)
|
| 214 |
-
indices = list(range(num_episodes_total))
|
| 215 |
-
# Shuffle the episode indices
|
| 216 |
-
random.shuffle(indices)
|
| 217 |
-
# Split into train and val sets
|
| 218 |
-
num_episodes_val = int(num_episodes_total * args.percent_val)
|
| 219 |
-
print(f"Total # episodes: {num_episodes_total}; using {num_episodes_val} ({args.percent_val:.2f}%) for val set")
|
| 220 |
-
num_episodes_train = num_episodes_total - num_episodes_val
|
| 221 |
-
train_indices = indices[:num_episodes_train]
|
| 222 |
-
val_indices = indices[num_episodes_train:]
|
| 223 |
-
train_demo_paths = [all_demo_paths[i] for i in train_indices]
|
| 224 |
-
val_demo_paths = [all_demo_paths[i] for i in val_indices]
|
| 225 |
-
# Preprocess all episodes and save the result
|
| 226 |
-
out_dataset_dir_train = os.path.join(out_dataset_dir, "train")
|
| 227 |
-
out_dataset_dir_val = os.path.join(out_dataset_dir, "val")
|
| 228 |
-
os.makedirs(out_dataset_dir_train, exist_ok=True)
|
| 229 |
-
os.makedirs(out_dataset_dir_val, exist_ok=True)
|
| 230 |
-
load_and_preprocess_all_episodes(train_demo_paths, out_dataset_dir_train)
|
| 231 |
-
load_and_preprocess_all_episodes(val_demo_paths, out_dataset_dir_val)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
if __name__ == "__main__":
|
| 235 |
-
parser = argparse.ArgumentParser()
|
| 236 |
-
parser.add_argument(
|
| 237 |
-
"--dataset_path",
|
| 238 |
-
required=True,
|
| 239 |
-
help="Path to raw ALOHA dataset directory. Example: /PATH/TO/USER/data/aloha_raw/put_green_pepper_into_pot/",
|
| 240 |
-
)
|
| 241 |
-
parser.add_argument(
|
| 242 |
-
"--out_base_dir",
|
| 243 |
-
required=True,
|
| 244 |
-
help="Path to directory in which to save preprocessed dataset. Example: /PATH/TO/USER/data/aloha_preprocessed/",
|
| 245 |
-
)
|
| 246 |
-
parser.add_argument(
|
| 247 |
-
"--percent_val",
|
| 248 |
-
type=float,
|
| 249 |
-
help="Percent of dataset to use as validation set (measured in episodes, not steps).",
|
| 250 |
-
default=0.05,
|
| 251 |
-
)
|
| 252 |
-
parser.add_argument(
|
| 253 |
-
"--img_resize_size",
|
| 254 |
-
type=int,
|
| 255 |
-
help="Size to resize images to. Final images will be square (img_resize_size x img_resize_size pixels).",
|
| 256 |
-
default=256,
|
| 257 |
-
)
|
| 258 |
-
args = parser.parse_args()
|
| 259 |
-
|
| 260 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/real_env.py
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
import time
|
| 2 |
-
import numpy as np
|
| 3 |
-
import collections
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import dm_env
|
| 6 |
-
|
| 7 |
-
from experiments.robot.aloha.constants import DT, START_ARM_POSE, MASTER_GRIPPER_JOINT_NORMALIZE_FN, PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN
|
| 8 |
-
from experiments.robot.aloha.constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
| 9 |
-
from experiments.robot.aloha.constants import PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
| 10 |
-
from experiments.robot.aloha.robot_utils import Recorder, ImageRecorder
|
| 11 |
-
from experiments.robot.aloha.robot_utils import setup_master_bot, setup_puppet_bot, move_arms, move_grippers
|
| 12 |
-
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
| 13 |
-
from interbotix_xs_msgs.msg import JointSingleCommand
|
| 14 |
-
|
| 15 |
-
import IPython
|
| 16 |
-
e = IPython.embed
|
| 17 |
-
|
| 18 |
-
class RealEnv:
|
| 19 |
-
"""
|
| 20 |
-
Environment for real robot bi-manual manipulation
|
| 21 |
-
Action space: [left_arm_qpos (6), # absolute joint position
|
| 22 |
-
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
| 23 |
-
right_arm_qpos (6), # absolute joint position
|
| 24 |
-
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
| 25 |
-
|
| 26 |
-
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
| 27 |
-
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
| 28 |
-
right_arm_qpos (6), # absolute joint position
|
| 29 |
-
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
| 30 |
-
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
| 31 |
-
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
| 32 |
-
right_arm_qvel (6), # absolute joint velocity (rad)
|
| 33 |
-
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
| 34 |
-
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
| 35 |
-
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
| 36 |
-
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
| 37 |
-
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
def __init__(self, init_node, setup_robots=True):
|
| 41 |
-
self.puppet_bot_left = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper",
|
| 42 |
-
robot_name=f'puppet_left', init_node=init_node)
|
| 43 |
-
self.puppet_bot_right = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper",
|
| 44 |
-
robot_name=f'puppet_right', init_node=False)
|
| 45 |
-
if setup_robots:
|
| 46 |
-
self.setup_robots()
|
| 47 |
-
|
| 48 |
-
self.recorder_left = Recorder('left', init_node=False)
|
| 49 |
-
self.recorder_right = Recorder('right', init_node=False)
|
| 50 |
-
self.image_recorder = ImageRecorder(init_node=False)
|
| 51 |
-
self.gripper_command = JointSingleCommand(name="gripper")
|
| 52 |
-
|
| 53 |
-
def setup_robots(self):
|
| 54 |
-
setup_puppet_bot(self.puppet_bot_left)
|
| 55 |
-
setup_puppet_bot(self.puppet_bot_right)
|
| 56 |
-
|
| 57 |
-
def get_qpos(self):
|
| 58 |
-
left_qpos_raw = self.recorder_left.qpos
|
| 59 |
-
right_qpos_raw = self.recorder_right.qpos
|
| 60 |
-
left_arm_qpos = left_qpos_raw[:6]
|
| 61 |
-
right_arm_qpos = right_qpos_raw[:6]
|
| 62 |
-
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])] # this is position not joint
|
| 63 |
-
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])] # this is position not joint
|
| 64 |
-
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
| 65 |
-
|
| 66 |
-
def get_qvel(self):
|
| 67 |
-
left_qvel_raw = self.recorder_left.qvel
|
| 68 |
-
right_qvel_raw = self.recorder_right.qvel
|
| 69 |
-
left_arm_qvel = left_qvel_raw[:6]
|
| 70 |
-
right_arm_qvel = right_qvel_raw[:6]
|
| 71 |
-
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
| 72 |
-
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
| 73 |
-
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
| 74 |
-
|
| 75 |
-
def get_effort(self):
|
| 76 |
-
left_effort_raw = self.recorder_left.effort
|
| 77 |
-
right_effort_raw = self.recorder_right.effort
|
| 78 |
-
left_robot_effort = left_effort_raw[:7]
|
| 79 |
-
right_robot_effort = right_effort_raw[:7]
|
| 80 |
-
return np.concatenate([left_robot_effort, right_robot_effort])
|
| 81 |
-
|
| 82 |
-
def get_images(self):
|
| 83 |
-
return self.image_recorder.get_images()
|
| 84 |
-
|
| 85 |
-
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
| 86 |
-
left_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
| 87 |
-
self.gripper_command.cmd = left_gripper_desired_joint
|
| 88 |
-
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
| 89 |
-
|
| 90 |
-
right_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(right_gripper_desired_pos_normalized)
|
| 91 |
-
self.gripper_command.cmd = right_gripper_desired_joint
|
| 92 |
-
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
| 93 |
-
|
| 94 |
-
def _reset_joints(self):
|
| 95 |
-
reset_position = START_ARM_POSE[:6]
|
| 96 |
-
move_arms([self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1)
|
| 97 |
-
|
| 98 |
-
def _reset_gripper(self):
|
| 99 |
-
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
|
| 100 |
-
move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)
|
| 101 |
-
move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1)
|
| 102 |
-
|
| 103 |
-
def _get_obs(self):
|
| 104 |
-
obs = collections.OrderedDict()
|
| 105 |
-
obs['qpos'] = self.get_qpos()
|
| 106 |
-
obs['qvel'] = self.get_qvel()
|
| 107 |
-
obs['effort'] = self.get_effort()
|
| 108 |
-
obs['images'] = self.get_images()
|
| 109 |
-
return obs
|
| 110 |
-
|
| 111 |
-
def get_observation(self, t=0):
|
| 112 |
-
step_type = dm_env.StepType.FIRST if t == 0 else dm_env.StepType.MID
|
| 113 |
-
return dm_env.TimeStep(
|
| 114 |
-
step_type=step_type,
|
| 115 |
-
reward=self.get_reward(),
|
| 116 |
-
discount=None,
|
| 117 |
-
observation=self._get_obs()
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
def get_reward(self):
|
| 121 |
-
return 0
|
| 122 |
-
|
| 123 |
-
def reset(self, fake=False):
|
| 124 |
-
if not fake:
|
| 125 |
-
# Reboot puppet robot gripper motors
|
| 126 |
-
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
| 127 |
-
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
| 128 |
-
self._reset_joints()
|
| 129 |
-
self._reset_gripper()
|
| 130 |
-
return dm_env.TimeStep(
|
| 131 |
-
step_type=dm_env.StepType.FIRST,
|
| 132 |
-
reward=self.get_reward(),
|
| 133 |
-
discount=None,
|
| 134 |
-
observation=self._get_obs())
|
| 135 |
-
|
| 136 |
-
def step(self, action):
|
| 137 |
-
state_len = int(len(action) / 2)
|
| 138 |
-
left_action = action[:state_len]
|
| 139 |
-
right_action = action[state_len:]
|
| 140 |
-
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
| 141 |
-
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
| 142 |
-
self.set_gripper_pose(left_action[-1], right_action[-1])
|
| 143 |
-
time.sleep(DT)
|
| 144 |
-
return dm_env.TimeStep(
|
| 145 |
-
step_type=dm_env.StepType.MID,
|
| 146 |
-
reward=self.get_reward(),
|
| 147 |
-
discount=None,
|
| 148 |
-
observation=self._get_obs())
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def get_action(master_bot_left, master_bot_right):
|
| 152 |
-
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
| 153 |
-
# Arm actions
|
| 154 |
-
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
| 155 |
-
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
|
| 156 |
-
# Gripper actions
|
| 157 |
-
action[6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
| 158 |
-
action[7+6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
| 159 |
-
|
| 160 |
-
return action
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
def make_real_env(init_node, setup_robots=True):
|
| 164 |
-
env = RealEnv(init_node, setup_robots)
|
| 165 |
-
return env
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
def test_real_teleop():
|
| 169 |
-
"""
|
| 170 |
-
Test bimanual teleoperation and show image observations onscreen.
|
| 171 |
-
It first reads joint poses from both master arms.
|
| 172 |
-
Then use it as actions to step the environment.
|
| 173 |
-
The environment returns full observations including images.
|
| 174 |
-
|
| 175 |
-
An alternative approach is to have separate scripts for teleoperation and observation recording.
|
| 176 |
-
This script will result in higher fidelity (obs, action) pairs
|
| 177 |
-
"""
|
| 178 |
-
|
| 179 |
-
onscreen_render = True
|
| 180 |
-
render_cam = 'cam_left_wrist'
|
| 181 |
-
|
| 182 |
-
# source of data
|
| 183 |
-
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
| 184 |
-
robot_name=f'master_left', init_node=True)
|
| 185 |
-
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
| 186 |
-
robot_name=f'master_right', init_node=False)
|
| 187 |
-
setup_master_bot(master_bot_left)
|
| 188 |
-
setup_master_bot(master_bot_right)
|
| 189 |
-
|
| 190 |
-
# setup the environment
|
| 191 |
-
env = make_real_env(init_node=False)
|
| 192 |
-
ts = env.reset(fake=True)
|
| 193 |
-
episode = [ts]
|
| 194 |
-
# setup visualization
|
| 195 |
-
if onscreen_render:
|
| 196 |
-
ax = plt.subplot()
|
| 197 |
-
plt_img = ax.imshow(ts.observation['images'][render_cam])
|
| 198 |
-
plt.ion()
|
| 199 |
-
|
| 200 |
-
for t in range(1000):
|
| 201 |
-
action = get_action(master_bot_left, master_bot_right)
|
| 202 |
-
ts = env.step(action)
|
| 203 |
-
episode.append(ts)
|
| 204 |
-
|
| 205 |
-
if onscreen_render:
|
| 206 |
-
plt_img.set_data(ts.observation['images'][render_cam])
|
| 207 |
-
plt.pause(DT)
|
| 208 |
-
else:
|
| 209 |
-
time.sleep(DT)
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
if __name__ == '__main__':
|
| 213 |
-
test_real_teleop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/requirements_aloha.txt
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
numpy<2
|
| 2 |
-
draccus
|
| 3 |
-
torchvision
|
| 4 |
-
torch
|
| 5 |
-
pyquaternion
|
| 6 |
-
pyyaml
|
| 7 |
-
rospkg
|
| 8 |
-
pexpect
|
| 9 |
-
mujoco==2.3.7
|
| 10 |
-
dm_control==1.0.14
|
| 11 |
-
opencv-python
|
| 12 |
-
matplotlib
|
| 13 |
-
einops
|
| 14 |
-
packaging
|
| 15 |
-
h5py
|
| 16 |
-
traitlets
|
| 17 |
-
ipdb
|
| 18 |
-
IPython
|
| 19 |
-
modern_robotics
|
| 20 |
-
Pillow
|
| 21 |
-
termcolor
|
| 22 |
-
imageio[ffmpeg]
|
| 23 |
-
uvicorn
|
| 24 |
-
fastapi
|
| 25 |
-
requests
|
| 26 |
-
json_numpy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/robot_utils.py
DELETED
|
@@ -1,187 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import time
|
| 3 |
-
from experiments.robot.aloha.constants import DT
|
| 4 |
-
from interbotix_xs_msgs.msg import JointSingleCommand
|
| 5 |
-
|
| 6 |
-
import IPython
|
| 7 |
-
e = IPython.embed
|
| 8 |
-
|
| 9 |
-
class ImageRecorder:
|
| 10 |
-
def __init__(self, init_node=True, is_debug=False):
|
| 11 |
-
from collections import deque
|
| 12 |
-
import rospy
|
| 13 |
-
from cv_bridge import CvBridge
|
| 14 |
-
from sensor_msgs.msg import Image
|
| 15 |
-
self.is_debug = is_debug
|
| 16 |
-
self.bridge = CvBridge()
|
| 17 |
-
self.camera_names = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
|
| 18 |
-
if init_node:
|
| 19 |
-
rospy.init_node('image_recorder', anonymous=True)
|
| 20 |
-
for cam_name in self.camera_names:
|
| 21 |
-
setattr(self, f'{cam_name}_image', None)
|
| 22 |
-
setattr(self, f'{cam_name}_secs', None)
|
| 23 |
-
setattr(self, f'{cam_name}_nsecs', None)
|
| 24 |
-
if cam_name == 'cam_high':
|
| 25 |
-
callback_func = self.image_cb_cam_high
|
| 26 |
-
elif cam_name == 'cam_low':
|
| 27 |
-
callback_func = self.image_cb_cam_low
|
| 28 |
-
elif cam_name == 'cam_left_wrist':
|
| 29 |
-
callback_func = self.image_cb_cam_left_wrist
|
| 30 |
-
elif cam_name == 'cam_right_wrist':
|
| 31 |
-
callback_func = self.image_cb_cam_right_wrist
|
| 32 |
-
else:
|
| 33 |
-
raise NotImplementedError
|
| 34 |
-
rospy.Subscriber(f"/usb_{cam_name}/image_raw", Image, callback_func)
|
| 35 |
-
if self.is_debug:
|
| 36 |
-
setattr(self, f'{cam_name}_timestamps', deque(maxlen=50))
|
| 37 |
-
time.sleep(0.5)
|
| 38 |
-
|
| 39 |
-
def image_cb(self, cam_name, data):
|
| 40 |
-
setattr(self, f'{cam_name}_image', self.bridge.imgmsg_to_cv2(data, desired_encoding='passthrough'))
|
| 41 |
-
setattr(self, f'{cam_name}_secs', data.header.stamp.secs)
|
| 42 |
-
setattr(self, f'{cam_name}_nsecs', data.header.stamp.nsecs)
|
| 43 |
-
# cv2.imwrite('/home/tonyzhao/Desktop/sample.jpg', cv_image)
|
| 44 |
-
if self.is_debug:
|
| 45 |
-
getattr(self, f'{cam_name}_timestamps').append(data.header.stamp.secs + data.header.stamp.secs * 1e-9)
|
| 46 |
-
|
| 47 |
-
def image_cb_cam_high(self, data):
|
| 48 |
-
cam_name = 'cam_high'
|
| 49 |
-
return self.image_cb(cam_name, data)
|
| 50 |
-
|
| 51 |
-
def image_cb_cam_low(self, data):
|
| 52 |
-
cam_name = 'cam_low'
|
| 53 |
-
return self.image_cb(cam_name, data)
|
| 54 |
-
|
| 55 |
-
def image_cb_cam_left_wrist(self, data):
|
| 56 |
-
cam_name = 'cam_left_wrist'
|
| 57 |
-
return self.image_cb(cam_name, data)
|
| 58 |
-
|
| 59 |
-
def image_cb_cam_right_wrist(self, data):
|
| 60 |
-
cam_name = 'cam_right_wrist'
|
| 61 |
-
return self.image_cb(cam_name, data)
|
| 62 |
-
|
| 63 |
-
def get_images(self):
|
| 64 |
-
image_dict = dict()
|
| 65 |
-
for cam_name in self.camera_names:
|
| 66 |
-
image_dict[cam_name] = getattr(self, f'{cam_name}_image')
|
| 67 |
-
return image_dict
|
| 68 |
-
|
| 69 |
-
def print_diagnostics(self):
|
| 70 |
-
def dt_helper(l):
|
| 71 |
-
l = np.array(l)
|
| 72 |
-
diff = l[1:] - l[:-1]
|
| 73 |
-
return np.mean(diff)
|
| 74 |
-
for cam_name in self.camera_names:
|
| 75 |
-
image_freq = 1 / dt_helper(getattr(self, f'{cam_name}_timestamps'))
|
| 76 |
-
print(f'{cam_name} {image_freq=:.2f}')
|
| 77 |
-
print()
|
| 78 |
-
|
| 79 |
-
class Recorder:
|
| 80 |
-
def __init__(self, side, init_node=True, is_debug=False):
|
| 81 |
-
from collections import deque
|
| 82 |
-
import rospy
|
| 83 |
-
from sensor_msgs.msg import JointState
|
| 84 |
-
from interbotix_xs_msgs.msg import JointGroupCommand, JointSingleCommand
|
| 85 |
-
|
| 86 |
-
self.secs = None
|
| 87 |
-
self.nsecs = None
|
| 88 |
-
self.qpos = None
|
| 89 |
-
self.effort = None
|
| 90 |
-
self.arm_command = None
|
| 91 |
-
self.gripper_command = None
|
| 92 |
-
self.is_debug = is_debug
|
| 93 |
-
|
| 94 |
-
if init_node:
|
| 95 |
-
rospy.init_node('recorder', anonymous=True)
|
| 96 |
-
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
| 97 |
-
rospy.Subscriber(f"/puppet_{side}/commands/joint_group", JointGroupCommand, self.puppet_arm_commands_cb)
|
| 98 |
-
rospy.Subscriber(f"/puppet_{side}/commands/joint_single", JointSingleCommand, self.puppet_gripper_commands_cb)
|
| 99 |
-
if self.is_debug:
|
| 100 |
-
self.joint_timestamps = deque(maxlen=50)
|
| 101 |
-
self.arm_command_timestamps = deque(maxlen=50)
|
| 102 |
-
self.gripper_command_timestamps = deque(maxlen=50)
|
| 103 |
-
time.sleep(0.1)
|
| 104 |
-
|
| 105 |
-
def puppet_state_cb(self, data):
|
| 106 |
-
self.qpos = data.position
|
| 107 |
-
self.qvel = data.velocity
|
| 108 |
-
self.effort = data.effort
|
| 109 |
-
self.data = data
|
| 110 |
-
if self.is_debug:
|
| 111 |
-
self.joint_timestamps.append(time.time())
|
| 112 |
-
|
| 113 |
-
def puppet_arm_commands_cb(self, data):
|
| 114 |
-
self.arm_command = data.cmd
|
| 115 |
-
if self.is_debug:
|
| 116 |
-
self.arm_command_timestamps.append(time.time())
|
| 117 |
-
|
| 118 |
-
def puppet_gripper_commands_cb(self, data):
|
| 119 |
-
self.gripper_command = data.cmd
|
| 120 |
-
if self.is_debug:
|
| 121 |
-
self.gripper_command_timestamps.append(time.time())
|
| 122 |
-
|
| 123 |
-
def print_diagnostics(self):
|
| 124 |
-
def dt_helper(l):
|
| 125 |
-
l = np.array(l)
|
| 126 |
-
diff = l[1:] - l[:-1]
|
| 127 |
-
return np.mean(diff)
|
| 128 |
-
|
| 129 |
-
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
| 130 |
-
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
| 131 |
-
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
| 132 |
-
|
| 133 |
-
print(f'{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n')
|
| 134 |
-
|
| 135 |
-
def get_arm_joint_positions(bot):
|
| 136 |
-
return bot.arm.core.joint_states.position[:6]
|
| 137 |
-
|
| 138 |
-
def get_arm_gripper_positions(bot):
|
| 139 |
-
joint_position = bot.gripper.core.joint_states.position[6]
|
| 140 |
-
return joint_position
|
| 141 |
-
|
| 142 |
-
def move_arms(bot_list, target_pose_list, move_time=1):
|
| 143 |
-
num_steps = int(move_time / DT)
|
| 144 |
-
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
| 145 |
-
traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)]
|
| 146 |
-
for t in range(num_steps):
|
| 147 |
-
for bot_id, bot in enumerate(bot_list):
|
| 148 |
-
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
| 149 |
-
time.sleep(DT)
|
| 150 |
-
|
| 151 |
-
def move_grippers(bot_list, target_pose_list, move_time):
|
| 152 |
-
gripper_command = JointSingleCommand(name="gripper")
|
| 153 |
-
num_steps = int(move_time / DT)
|
| 154 |
-
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
| 155 |
-
traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)]
|
| 156 |
-
for t in range(num_steps):
|
| 157 |
-
for bot_id, bot in enumerate(bot_list):
|
| 158 |
-
gripper_command.cmd = traj_list[bot_id][t]
|
| 159 |
-
bot.gripper.core.pub_single.publish(gripper_command)
|
| 160 |
-
time.sleep(DT)
|
| 161 |
-
|
| 162 |
-
def setup_puppet_bot(bot):
|
| 163 |
-
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
| 164 |
-
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
| 165 |
-
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
| 166 |
-
torque_on(bot)
|
| 167 |
-
|
| 168 |
-
def setup_master_bot(bot):
|
| 169 |
-
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
| 170 |
-
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
| 171 |
-
torque_off(bot)
|
| 172 |
-
|
| 173 |
-
def set_standard_pid_gains(bot):
|
| 174 |
-
bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 800)
|
| 175 |
-
bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0)
|
| 176 |
-
|
| 177 |
-
def set_low_pid_gains(bot):
|
| 178 |
-
bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 100)
|
| 179 |
-
bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0)
|
| 180 |
-
|
| 181 |
-
def torque_off(bot):
|
| 182 |
-
bot.dxl.robot_torque_enable("group", "arm", False)
|
| 183 |
-
bot.dxl.robot_torque_enable("single", "gripper", False)
|
| 184 |
-
|
| 185 |
-
def torque_on(bot):
|
| 186 |
-
bot.dxl.robot_torque_enable("group", "arm", True)
|
| 187 |
-
bot.dxl.robot_torque_enable("single", "gripper", True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/aloha/run_aloha_eval.py
DELETED
|
@@ -1,385 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
run_aloha_eval.py
|
| 3 |
-
|
| 4 |
-
Evaluates a model in a real-world ALOHA environment.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import logging
|
| 8 |
-
import os
|
| 9 |
-
import socket
|
| 10 |
-
import sys
|
| 11 |
-
import time
|
| 12 |
-
from collections import deque
|
| 13 |
-
from dataclasses import dataclass
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Optional, Union
|
| 16 |
-
|
| 17 |
-
import draccus
|
| 18 |
-
import tqdm
|
| 19 |
-
|
| 20 |
-
# Append current directory so that interpreter can find experiments.robot
|
| 21 |
-
sys.path.append(".")
|
| 22 |
-
from experiments.robot.aloha.aloha_utils import (
|
| 23 |
-
get_aloha_env,
|
| 24 |
-
get_aloha_image,
|
| 25 |
-
get_aloha_wrist_images,
|
| 26 |
-
get_next_task_label,
|
| 27 |
-
save_rollout_video,
|
| 28 |
-
)
|
| 29 |
-
from experiments.robot.openvla_utils import (
|
| 30 |
-
get_action_from_server,
|
| 31 |
-
resize_image_for_policy,
|
| 32 |
-
)
|
| 33 |
-
from experiments.robot.robot_utils import (
|
| 34 |
-
DATE_TIME,
|
| 35 |
-
get_image_resize_size,
|
| 36 |
-
set_seed_everywhere,
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
# Set up logging
|
| 40 |
-
logging.basicConfig(
|
| 41 |
-
level=logging.INFO,
|
| 42 |
-
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 43 |
-
handlers=[logging.StreamHandler()],
|
| 44 |
-
)
|
| 45 |
-
logger = logging.getLogger(__name__)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
@dataclass
|
| 49 |
-
class GenerateConfig:
|
| 50 |
-
# fmt: off
|
| 51 |
-
|
| 52 |
-
#################################################################################################################
|
| 53 |
-
# Model-specific parameters
|
| 54 |
-
#################################################################################################################
|
| 55 |
-
model_family: str = "openvla" # Model family
|
| 56 |
-
|
| 57 |
-
center_crop: bool = True # Center crop? (if trained w/ random crop image aug)
|
| 58 |
-
num_open_loop_steps: int = 25 # Number of actions to execute open-loop before requerying policy
|
| 59 |
-
|
| 60 |
-
use_vla_server: bool = True # Whether to query remote VLA server for actions
|
| 61 |
-
vla_server_url: Union[str, Path] = "" # Remote VLA server URL (set to 127.0.0.1 if on same machine)
|
| 62 |
-
|
| 63 |
-
#################################################################################################################
|
| 64 |
-
# ALOHA environment-specific parameters
|
| 65 |
-
#################################################################################################################
|
| 66 |
-
num_rollouts_planned: int = 50 # Number of test rollouts
|
| 67 |
-
max_steps: int = 1500 # Max number of steps per rollout
|
| 68 |
-
use_relative_actions: bool = False # Whether to use relative actions (delta joint angles)
|
| 69 |
-
|
| 70 |
-
#################################################################################################################
|
| 71 |
-
# Utils
|
| 72 |
-
#################################################################################################################
|
| 73 |
-
run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
|
| 74 |
-
local_log_dir: str = "./experiments/logs" # Local directory for eval logs
|
| 75 |
-
|
| 76 |
-
seed: int = 7 # Random Seed (for reproducibility)
|
| 77 |
-
|
| 78 |
-
# fmt: on
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def validate_config(cfg: GenerateConfig) -> None:
|
| 82 |
-
"""Validate configuration parameters."""
|
| 83 |
-
assert cfg.use_vla_server, (
|
| 84 |
-
"Must use VLA server (server-client interface) to query model and get actions! Please set --use_vla_server=True"
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def setup_logging(cfg: GenerateConfig):
|
| 89 |
-
"""Set up logging to file."""
|
| 90 |
-
# Create run ID
|
| 91 |
-
run_id = f"EVAL-{cfg.model_family}-{DATE_TIME}"
|
| 92 |
-
if cfg.run_id_note is not None:
|
| 93 |
-
run_id += f"--{cfg.run_id_note}"
|
| 94 |
-
|
| 95 |
-
# Set up local logging
|
| 96 |
-
os.makedirs(cfg.local_log_dir, exist_ok=True)
|
| 97 |
-
local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
|
| 98 |
-
log_file = open(local_log_filepath, "w")
|
| 99 |
-
logger.info(f"Logging to local log file: {local_log_filepath}")
|
| 100 |
-
|
| 101 |
-
return log_file, local_log_filepath, run_id
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def log_message(message: str, log_file=None):
|
| 105 |
-
"""Log a message to console and optionally to a log file."""
|
| 106 |
-
print(message)
|
| 107 |
-
logger.info(message)
|
| 108 |
-
if log_file:
|
| 109 |
-
log_file.write(message + "\n")
|
| 110 |
-
log_file.flush()
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def get_server_endpoint(cfg: GenerateConfig):
|
| 114 |
-
"""Get the server endpoint for remote inference."""
|
| 115 |
-
ip_address = socket.gethostbyname(cfg.vla_server_url)
|
| 116 |
-
return f"http://{ip_address}:8777/act"
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def prepare_observation(obs, resize_size):
|
| 120 |
-
"""Prepare observation for policy input."""
|
| 121 |
-
# Get preprocessed images
|
| 122 |
-
img = get_aloha_image(obs)
|
| 123 |
-
left_wrist_img, right_wrist_img = get_aloha_wrist_images(obs)
|
| 124 |
-
|
| 125 |
-
# Resize images to size expected by model
|
| 126 |
-
img_resized = resize_image_for_policy(img, resize_size)
|
| 127 |
-
left_wrist_img_resized = resize_image_for_policy(left_wrist_img, resize_size)
|
| 128 |
-
right_wrist_img_resized = resize_image_for_policy(right_wrist_img, resize_size)
|
| 129 |
-
|
| 130 |
-
# Prepare observations dict
|
| 131 |
-
observation = {
|
| 132 |
-
"full_image": img_resized,
|
| 133 |
-
"left_wrist_image": left_wrist_img_resized,
|
| 134 |
-
"right_wrist_image": right_wrist_img_resized,
|
| 135 |
-
"state": obs.observation["qpos"],
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
return observation, img_resized, left_wrist_img_resized, right_wrist_img_resized
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def run_episode(
|
| 142 |
-
cfg: GenerateConfig,
|
| 143 |
-
env,
|
| 144 |
-
task_description: str,
|
| 145 |
-
server_endpoint: str,
|
| 146 |
-
resize_size,
|
| 147 |
-
log_file=None,
|
| 148 |
-
):
|
| 149 |
-
"""Run a single episode in the ALOHA environment."""
|
| 150 |
-
# Define control frequency
|
| 151 |
-
STEP_DURATION_IN_SEC = 1.0 / 25.0
|
| 152 |
-
|
| 153 |
-
# Reset environment
|
| 154 |
-
obs = env.reset()
|
| 155 |
-
|
| 156 |
-
# Initialize action queue
|
| 157 |
-
action_queue = deque(maxlen=cfg.num_open_loop_steps)
|
| 158 |
-
|
| 159 |
-
# Setup
|
| 160 |
-
t = 0
|
| 161 |
-
curr_state = None
|
| 162 |
-
replay_images = []
|
| 163 |
-
replay_images_resized = []
|
| 164 |
-
replay_images_left_wrist_resized = []
|
| 165 |
-
replay_images_right_wrist_resized = []
|
| 166 |
-
|
| 167 |
-
log_message("Prepare the scene, and then press Enter to begin...", log_file)
|
| 168 |
-
input()
|
| 169 |
-
|
| 170 |
-
# Reset environment again to fetch first timestep observation
|
| 171 |
-
obs = env.reset()
|
| 172 |
-
|
| 173 |
-
# Fetch initial robot state (but sleep first so that robot stops moving)
|
| 174 |
-
time.sleep(2)
|
| 175 |
-
curr_state = env.get_qpos()
|
| 176 |
-
|
| 177 |
-
episode_start_time = time.time()
|
| 178 |
-
total_model_query_time = 0.0
|
| 179 |
-
|
| 180 |
-
try:
|
| 181 |
-
while t < cfg.max_steps:
|
| 182 |
-
# Get step start time (used to compute how much to sleep between steps)
|
| 183 |
-
step_start_time = time.time()
|
| 184 |
-
|
| 185 |
-
# Get observation
|
| 186 |
-
obs = env.get_observation(t=t)
|
| 187 |
-
|
| 188 |
-
# Save raw high camera image for replay video
|
| 189 |
-
replay_images.append(obs.observation["images"]["cam_high"])
|
| 190 |
-
|
| 191 |
-
# If action queue is empty, requery model
|
| 192 |
-
if len(action_queue) == 0:
|
| 193 |
-
# Prepare observation
|
| 194 |
-
observation, img_resized, left_wrist_resized, right_wrist_resized = prepare_observation(obs, resize_size)
|
| 195 |
-
observation["instruction"] = task_description
|
| 196 |
-
|
| 197 |
-
# Save processed images for replay
|
| 198 |
-
replay_images_resized.append(img_resized)
|
| 199 |
-
replay_images_left_wrist_resized.append(left_wrist_resized)
|
| 200 |
-
replay_images_right_wrist_resized.append(right_wrist_resized)
|
| 201 |
-
|
| 202 |
-
# Query model to get action
|
| 203 |
-
log_message("Requerying model...", log_file)
|
| 204 |
-
model_query_start_time = time.time()
|
| 205 |
-
actions = get_action_from_server(observation, server_endpoint)
|
| 206 |
-
actions = actions[: cfg.num_open_loop_steps]
|
| 207 |
-
total_model_query_time += time.time() - model_query_start_time
|
| 208 |
-
action_queue.extend(actions)
|
| 209 |
-
|
| 210 |
-
# Get action from queue
|
| 211 |
-
action = action_queue.popleft()
|
| 212 |
-
log_message("-----------------------------------------------------", log_file)
|
| 213 |
-
log_message(f"t: {t}", log_file)
|
| 214 |
-
log_message(f"action: {action}", log_file)
|
| 215 |
-
|
| 216 |
-
# Execute action in environment
|
| 217 |
-
if cfg.use_relative_actions:
|
| 218 |
-
# Get absolute joint angles from relative action
|
| 219 |
-
rel_action = action
|
| 220 |
-
target_state = curr_state + rel_action
|
| 221 |
-
obs = env.step(target_state.tolist())
|
| 222 |
-
# Update current state (assume it is the commanded target state)
|
| 223 |
-
curr_state = target_state
|
| 224 |
-
else:
|
| 225 |
-
obs = env.step(action.tolist())
|
| 226 |
-
t += 1
|
| 227 |
-
|
| 228 |
-
# Sleep until next timestep
|
| 229 |
-
step_elapsed_time = time.time() - step_start_time
|
| 230 |
-
if step_elapsed_time < STEP_DURATION_IN_SEC:
|
| 231 |
-
time_to_sleep = STEP_DURATION_IN_SEC - step_elapsed_time
|
| 232 |
-
log_message(f"Sleeping {time_to_sleep} sec...", log_file)
|
| 233 |
-
time.sleep(time_to_sleep)
|
| 234 |
-
|
| 235 |
-
except (KeyboardInterrupt, Exception) as e:
|
| 236 |
-
if isinstance(e, KeyboardInterrupt):
|
| 237 |
-
log_message("\nCaught KeyboardInterrupt: Terminating episode early.", log_file)
|
| 238 |
-
else:
|
| 239 |
-
log_message(f"\nCaught exception: {e}", log_file)
|
| 240 |
-
|
| 241 |
-
episode_end_time = time.time()
|
| 242 |
-
|
| 243 |
-
# Get success feedback from user
|
| 244 |
-
user_input = input("Success? Enter 'y' or 'n': ")
|
| 245 |
-
success = True if user_input.lower() == "y" else False
|
| 246 |
-
|
| 247 |
-
# Calculate episode statistics
|
| 248 |
-
episode_stats = {
|
| 249 |
-
"success": success,
|
| 250 |
-
"total_steps": t,
|
| 251 |
-
"model_query_time": total_model_query_time,
|
| 252 |
-
"episode_duration": episode_end_time - episode_start_time,
|
| 253 |
-
}
|
| 254 |
-
|
| 255 |
-
return (
|
| 256 |
-
episode_stats,
|
| 257 |
-
replay_images,
|
| 258 |
-
replay_images_resized,
|
| 259 |
-
replay_images_left_wrist_resized,
|
| 260 |
-
replay_images_right_wrist_resized,
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
def save_episode_videos(
|
| 265 |
-
replay_images,
|
| 266 |
-
replay_images_resized,
|
| 267 |
-
replay_images_left_wrist,
|
| 268 |
-
replay_images_right_wrist,
|
| 269 |
-
episode_idx,
|
| 270 |
-
success,
|
| 271 |
-
task_description,
|
| 272 |
-
log_file=None,
|
| 273 |
-
):
|
| 274 |
-
"""Save videos of the episode from different camera angles."""
|
| 275 |
-
# Save main replay video
|
| 276 |
-
save_rollout_video(replay_images, episode_idx, success=success, task_description=task_description, log_file=log_file)
|
| 277 |
-
|
| 278 |
-
# Save processed view videos
|
| 279 |
-
save_rollout_video(
|
| 280 |
-
replay_images_resized,
|
| 281 |
-
episode_idx,
|
| 282 |
-
success=success,
|
| 283 |
-
task_description=task_description,
|
| 284 |
-
log_file=log_file,
|
| 285 |
-
notes="resized",
|
| 286 |
-
)
|
| 287 |
-
save_rollout_video(
|
| 288 |
-
replay_images_left_wrist,
|
| 289 |
-
episode_idx,
|
| 290 |
-
success=success,
|
| 291 |
-
task_description=task_description,
|
| 292 |
-
log_file=log_file,
|
| 293 |
-
notes="left_wrist_resized",
|
| 294 |
-
)
|
| 295 |
-
save_rollout_video(
|
| 296 |
-
replay_images_right_wrist,
|
| 297 |
-
episode_idx,
|
| 298 |
-
success=success,
|
| 299 |
-
task_description=task_description,
|
| 300 |
-
log_file=log_file,
|
| 301 |
-
notes="right_wrist_resized",
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
@draccus.wrap()
|
| 306 |
-
def eval_aloha(cfg: GenerateConfig) -> None:
|
| 307 |
-
"""Main function to evaluate a trained policy in a real-world ALOHA environment."""
|
| 308 |
-
# Validate configuration
|
| 309 |
-
validate_config(cfg)
|
| 310 |
-
|
| 311 |
-
# Set random seed
|
| 312 |
-
set_seed_everywhere(cfg.seed)
|
| 313 |
-
|
| 314 |
-
# Setup logging
|
| 315 |
-
log_file, local_log_filepath, run_id = setup_logging(cfg)
|
| 316 |
-
|
| 317 |
-
# Get expected image dimensions
|
| 318 |
-
resize_size = get_image_resize_size(cfg)
|
| 319 |
-
|
| 320 |
-
# Get ALOHA environment
|
| 321 |
-
env = get_aloha_env()
|
| 322 |
-
|
| 323 |
-
# Get server endpoint for remote inference
|
| 324 |
-
server_endpoint = get_server_endpoint(cfg)
|
| 325 |
-
|
| 326 |
-
# Initialize task description
|
| 327 |
-
task_description = ""
|
| 328 |
-
|
| 329 |
-
# Start evaluation
|
| 330 |
-
num_rollouts_completed, total_successes = 0, 0
|
| 331 |
-
|
| 332 |
-
for episode_idx in tqdm.tqdm(range(cfg.num_rollouts_planned)):
|
| 333 |
-
# Get task description from user
|
| 334 |
-
task_description = get_next_task_label(task_description)
|
| 335 |
-
log_message(f"\nTask: {task_description}", log_file)
|
| 336 |
-
|
| 337 |
-
log_message(f"Starting episode {num_rollouts_completed + 1}...", log_file)
|
| 338 |
-
|
| 339 |
-
# Run episode
|
| 340 |
-
episode_stats, replay_images, replay_images_resized, replay_images_left_wrist, replay_images_right_wrist = (
|
| 341 |
-
run_episode(cfg, env, task_description, server_endpoint, resize_size, log_file)
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
# Update counters
|
| 345 |
-
num_rollouts_completed += 1
|
| 346 |
-
if episode_stats["success"]:
|
| 347 |
-
total_successes += 1
|
| 348 |
-
|
| 349 |
-
# Save videos
|
| 350 |
-
save_episode_videos(
|
| 351 |
-
replay_images,
|
| 352 |
-
replay_images_resized,
|
| 353 |
-
replay_images_left_wrist,
|
| 354 |
-
replay_images_right_wrist,
|
| 355 |
-
num_rollouts_completed,
|
| 356 |
-
episode_stats["success"],
|
| 357 |
-
task_description,
|
| 358 |
-
log_file,
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
# Log results
|
| 362 |
-
log_message(f"Success: {episode_stats['success']}", log_file)
|
| 363 |
-
log_message(f"# episodes completed so far: {num_rollouts_completed}", log_file)
|
| 364 |
-
log_message(f"# successes: {total_successes} ({total_successes / num_rollouts_completed * 100:.1f}%)", log_file)
|
| 365 |
-
log_message(f"Total model query time: {episode_stats['model_query_time']:.2f} sec", log_file)
|
| 366 |
-
log_message(f"Total episode elapsed time: {episode_stats['episode_duration']:.2f} sec", log_file)
|
| 367 |
-
|
| 368 |
-
# Calculate final success rate
|
| 369 |
-
final_success_rate = float(total_successes) / float(num_rollouts_completed) if num_rollouts_completed > 0 else 0
|
| 370 |
-
|
| 371 |
-
# Log final results
|
| 372 |
-
log_message("\nFinal results:", log_file)
|
| 373 |
-
log_message(f"Total episodes: {num_rollouts_completed}", log_file)
|
| 374 |
-
log_message(f"Total successes: {total_successes}", log_file)
|
| 375 |
-
log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file)
|
| 376 |
-
|
| 377 |
-
# Close log file
|
| 378 |
-
if log_file:
|
| 379 |
-
log_file.close()
|
| 380 |
-
|
| 381 |
-
return final_success_rate
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
if __name__ == "__main__":
|
| 385 |
-
eval_aloha()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/libero/libero_requirements.txt
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
imageio[ffmpeg]
|
| 2 |
-
robosuite==1.4.1
|
| 3 |
-
bddl
|
| 4 |
-
easydict
|
| 5 |
-
cloudpickle
|
| 6 |
-
gym
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/libero/libero_utils.py
DELETED
|
@@ -1,87 +0,0 @@
|
|
| 1 |
-
"""Utils for evaluating policies in LIBERO simulation environments."""
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
import imageio
|
| 7 |
-
import numpy as np
|
| 8 |
-
import tensorflow as tf
|
| 9 |
-
from libero.libero import get_libero_path
|
| 10 |
-
from libero.libero.envs import OffScreenRenderEnv
|
| 11 |
-
|
| 12 |
-
from experiments.robot.robot_utils import (
|
| 13 |
-
DATE,
|
| 14 |
-
DATE_TIME,
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def get_libero_env(task, model_family, resolution=256):
|
| 19 |
-
"""Initializes and returns the LIBERO environment, along with the task description."""
|
| 20 |
-
task_description = task.language
|
| 21 |
-
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
|
| 22 |
-
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
|
| 23 |
-
env = OffScreenRenderEnv(**env_args)
|
| 24 |
-
env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
|
| 25 |
-
return env, task_description
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def get_libero_dummy_action(model_family: str):
|
| 29 |
-
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
|
| 30 |
-
return [0, 0, 0, 0, 0, 0, -1]
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def get_libero_image(obs):
|
| 34 |
-
"""Extracts third-person image from observations and preprocesses it."""
|
| 35 |
-
img = obs["agentview_image"]
|
| 36 |
-
img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
|
| 37 |
-
return img
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def get_libero_wrist_image(obs):
|
| 41 |
-
"""Extracts wrist camera image from observations and preprocesses it."""
|
| 42 |
-
img = obs["robot0_eye_in_hand_image"]
|
| 43 |
-
img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
|
| 44 |
-
return img
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None):
|
| 48 |
-
"""Saves an MP4 replay of an episode."""
|
| 49 |
-
rollout_dir = f"./rollouts/{DATE}"
|
| 50 |
-
os.makedirs(rollout_dir, exist_ok=True)
|
| 51 |
-
processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
|
| 52 |
-
mp4_path = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}.mp4"
|
| 53 |
-
video_writer = imageio.get_writer(mp4_path, fps=30)
|
| 54 |
-
for img in rollout_images:
|
| 55 |
-
video_writer.append_data(img)
|
| 56 |
-
video_writer.close()
|
| 57 |
-
print(f"Saved rollout MP4 at path {mp4_path}")
|
| 58 |
-
if log_file is not None:
|
| 59 |
-
log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
|
| 60 |
-
return mp4_path
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def quat2axisangle(quat):
|
| 64 |
-
"""
|
| 65 |
-
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
| 66 |
-
|
| 67 |
-
Converts quaternion to axis-angle format.
|
| 68 |
-
Returns a unit vector direction scaled by its angle in radians.
|
| 69 |
-
|
| 70 |
-
Args:
|
| 71 |
-
quat (np.array): (x,y,z,w) vec4 float angles
|
| 72 |
-
|
| 73 |
-
Returns:
|
| 74 |
-
np.array: (ax,ay,az) axis-angle exponential coordinates
|
| 75 |
-
"""
|
| 76 |
-
# clip quaternion
|
| 77 |
-
if quat[3] > 1.0:
|
| 78 |
-
quat[3] = 1.0
|
| 79 |
-
elif quat[3] < -1.0:
|
| 80 |
-
quat[3] = -1.0
|
| 81 |
-
|
| 82 |
-
den = np.sqrt(1.0 - quat[3] * quat[3])
|
| 83 |
-
if math.isclose(den, 0.0):
|
| 84 |
-
# This is (close to) a zero degree rotation, immediately return
|
| 85 |
-
return np.zeros(3)
|
| 86 |
-
|
| 87 |
-
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py
DELETED
|
@@ -1,249 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Regenerates a LIBERO dataset (HDF5 files) by replaying demonstrations in the environments.
|
| 3 |
-
|
| 4 |
-
Notes:
|
| 5 |
-
- We save image observations at 256x256px resolution (instead of 128x128).
|
| 6 |
-
- We filter out transitions with "no-op" (zero) actions that do not change the robot's state.
|
| 7 |
-
- We filter out unsuccessful demonstrations.
|
| 8 |
-
- In the LIBERO HDF5 data -> RLDS data conversion (not shown here), we rotate the images by
|
| 9 |
-
180 degrees because we observe that the environments return images that are upside down
|
| 10 |
-
on our platform.
|
| 11 |
-
|
| 12 |
-
Usage:
|
| 13 |
-
python experiments/robot/libero/regenerate_libero_dataset.py \
|
| 14 |
-
--libero_task_suite [ libero_spatial | libero_object | libero_goal | libero_10 ] \
|
| 15 |
-
--libero_raw_data_dir <PATH TO RAW HDF5 DATASET DIR> \
|
| 16 |
-
--libero_target_dir <PATH TO TARGET DIR>
|
| 17 |
-
|
| 18 |
-
Example (LIBERO-Spatial):
|
| 19 |
-
python experiments/robot/libero/regenerate_libero_dataset.py \
|
| 20 |
-
--libero_task_suite libero_spatial \
|
| 21 |
-
--libero_raw_data_dir ./LIBERO/libero/datasets/libero_spatial \
|
| 22 |
-
--libero_target_dir ./LIBERO/libero/datasets/libero_spatial_no_noops
|
| 23 |
-
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
import argparse
|
| 27 |
-
import json
|
| 28 |
-
import os
|
| 29 |
-
import time
|
| 30 |
-
|
| 31 |
-
import h5py
|
| 32 |
-
import numpy as np
|
| 33 |
-
import robosuite.utils.transform_utils as T
|
| 34 |
-
import tqdm
|
| 35 |
-
from libero.libero import benchmark
|
| 36 |
-
|
| 37 |
-
from experiments.robot.libero.libero_utils import (
|
| 38 |
-
get_libero_dummy_action,
|
| 39 |
-
get_libero_env,
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
IMAGE_RESOLUTION = 256
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def is_noop(action, prev_action=None, threshold=1e-4):
|
| 47 |
-
"""
|
| 48 |
-
Returns whether an action is a no-op action.
|
| 49 |
-
|
| 50 |
-
A no-op action satisfies two criteria:
|
| 51 |
-
(1) All action dimensions, except for the last one (gripper action), are near zero.
|
| 52 |
-
(2) The gripper action is equal to the previous timestep's gripper action.
|
| 53 |
-
|
| 54 |
-
Explanation of (2):
|
| 55 |
-
Naively filtering out actions with just criterion (1) is not good because you will
|
| 56 |
-
remove actions where the robot is staying still but opening/closing its gripper.
|
| 57 |
-
So you also need to consider the current state (by checking the previous timestep's
|
| 58 |
-
gripper action as a proxy) to determine whether the action really is a no-op.
|
| 59 |
-
"""
|
| 60 |
-
# Special case: Previous action is None if this is the first action in the episode
|
| 61 |
-
# Then we only care about criterion (1)
|
| 62 |
-
if prev_action is None:
|
| 63 |
-
return np.linalg.norm(action[:-1]) < threshold
|
| 64 |
-
|
| 65 |
-
# Normal case: Check both criteria (1) and (2)
|
| 66 |
-
gripper_action = action[-1]
|
| 67 |
-
prev_gripper_action = prev_action[-1]
|
| 68 |
-
return np.linalg.norm(action[:-1]) < threshold and gripper_action == prev_gripper_action
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def main(args):
|
| 72 |
-
print(f"Regenerating {args.libero_task_suite} dataset!")
|
| 73 |
-
|
| 74 |
-
# Create target directory
|
| 75 |
-
if os.path.isdir(args.libero_target_dir):
|
| 76 |
-
user_input = input(f"Target directory already exists at path: {args.libero_target_dir}\nEnter 'y' to overwrite the directory, or anything else to exit: ")
|
| 77 |
-
if user_input != 'y':
|
| 78 |
-
exit()
|
| 79 |
-
os.makedirs(args.libero_target_dir, exist_ok=True)
|
| 80 |
-
|
| 81 |
-
# Prepare JSON file to record success/false and initial states per episode
|
| 82 |
-
metainfo_json_dict = {}
|
| 83 |
-
metainfo_json_out_path = f"./experiments/robot/libero/{args.libero_task_suite}_metainfo.json"
|
| 84 |
-
with open(metainfo_json_out_path, "w") as f:
|
| 85 |
-
# Just test that we can write to this file (we overwrite it later)
|
| 86 |
-
json.dump(metainfo_json_dict, f)
|
| 87 |
-
|
| 88 |
-
# Get task suite
|
| 89 |
-
benchmark_dict = benchmark.get_benchmark_dict()
|
| 90 |
-
task_suite = benchmark_dict[args.libero_task_suite]()
|
| 91 |
-
num_tasks_in_suite = task_suite.n_tasks
|
| 92 |
-
|
| 93 |
-
# Setup
|
| 94 |
-
num_replays = 0
|
| 95 |
-
num_success = 0
|
| 96 |
-
num_noops = 0
|
| 97 |
-
|
| 98 |
-
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
|
| 99 |
-
# Get task in suite
|
| 100 |
-
task = task_suite.get_task(task_id)
|
| 101 |
-
env, task_description = get_libero_env(task, "llava", resolution=IMAGE_RESOLUTION)
|
| 102 |
-
|
| 103 |
-
# Get dataset for task
|
| 104 |
-
orig_data_path = os.path.join(args.libero_raw_data_dir, f"{task.name}_demo.hdf5")
|
| 105 |
-
assert os.path.exists(orig_data_path), f"Cannot find raw data file {orig_data_path}."
|
| 106 |
-
orig_data_file = h5py.File(orig_data_path, "r")
|
| 107 |
-
orig_data = orig_data_file["data"]
|
| 108 |
-
|
| 109 |
-
# Create new HDF5 file for regenerated demos
|
| 110 |
-
new_data_path = os.path.join(args.libero_target_dir, f"{task.name}_demo.hdf5")
|
| 111 |
-
new_data_file = h5py.File(new_data_path, "w")
|
| 112 |
-
grp = new_data_file.create_group("data")
|
| 113 |
-
|
| 114 |
-
for i in range(len(orig_data.keys())):
|
| 115 |
-
# Get demo data
|
| 116 |
-
demo_data = orig_data[f"demo_{i}"]
|
| 117 |
-
orig_actions = demo_data["actions"][()]
|
| 118 |
-
orig_states = demo_data["states"][()]
|
| 119 |
-
|
| 120 |
-
# Reset environment, set initial state, and wait a few steps for environment to settle
|
| 121 |
-
env.reset()
|
| 122 |
-
env.set_init_state(orig_states[0])
|
| 123 |
-
for _ in range(10):
|
| 124 |
-
obs, reward, done, info = env.step(get_libero_dummy_action("llava"))
|
| 125 |
-
|
| 126 |
-
# Set up new data lists
|
| 127 |
-
states = []
|
| 128 |
-
actions = []
|
| 129 |
-
ee_states = []
|
| 130 |
-
gripper_states = []
|
| 131 |
-
joint_states = []
|
| 132 |
-
robot_states = []
|
| 133 |
-
agentview_images = []
|
| 134 |
-
eye_in_hand_images = []
|
| 135 |
-
|
| 136 |
-
# Replay original demo actions in environment and record observations
|
| 137 |
-
for _, action in enumerate(orig_actions):
|
| 138 |
-
# Skip transitions with no-op actions
|
| 139 |
-
prev_action = actions[-1] if len(actions) > 0 else None
|
| 140 |
-
if is_noop(action, prev_action):
|
| 141 |
-
print(f"\tSkipping no-op action: {action}")
|
| 142 |
-
num_noops += 1
|
| 143 |
-
continue
|
| 144 |
-
|
| 145 |
-
if states == []:
|
| 146 |
-
# In the first timestep, since we're using the original initial state to initialize the environment,
|
| 147 |
-
# copy the initial state (first state in episode) over from the original HDF5 to the new one
|
| 148 |
-
states.append(orig_states[0])
|
| 149 |
-
robot_states.append(demo_data["robot_states"][0])
|
| 150 |
-
else:
|
| 151 |
-
# For all other timesteps, get state from environment and record it
|
| 152 |
-
states.append(env.sim.get_state().flatten())
|
| 153 |
-
robot_states.append(
|
| 154 |
-
np.concatenate([obs["robot0_gripper_qpos"], obs["robot0_eef_pos"], obs["robot0_eef_quat"]])
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
# Record original action (from demo)
|
| 158 |
-
actions.append(action)
|
| 159 |
-
|
| 160 |
-
# Record data returned by environment
|
| 161 |
-
if "robot0_gripper_qpos" in obs:
|
| 162 |
-
gripper_states.append(obs["robot0_gripper_qpos"])
|
| 163 |
-
joint_states.append(obs["robot0_joint_pos"])
|
| 164 |
-
ee_states.append(
|
| 165 |
-
np.hstack(
|
| 166 |
-
(
|
| 167 |
-
obs["robot0_eef_pos"],
|
| 168 |
-
T.quat2axisangle(obs["robot0_eef_quat"]),
|
| 169 |
-
)
|
| 170 |
-
)
|
| 171 |
-
)
|
| 172 |
-
agentview_images.append(obs["agentview_image"])
|
| 173 |
-
eye_in_hand_images.append(obs["robot0_eye_in_hand_image"])
|
| 174 |
-
|
| 175 |
-
# Execute demo action in environment
|
| 176 |
-
obs, reward, done, info = env.step(action.tolist())
|
| 177 |
-
|
| 178 |
-
# At end of episode, save replayed trajectories to new HDF5 files (only keep successes)
|
| 179 |
-
if done:
|
| 180 |
-
dones = np.zeros(len(actions)).astype(np.uint8)
|
| 181 |
-
dones[-1] = 1
|
| 182 |
-
rewards = np.zeros(len(actions)).astype(np.uint8)
|
| 183 |
-
rewards[-1] = 1
|
| 184 |
-
assert len(actions) == len(agentview_images)
|
| 185 |
-
|
| 186 |
-
ep_data_grp = grp.create_group(f"demo_{i}")
|
| 187 |
-
obs_grp = ep_data_grp.create_group("obs")
|
| 188 |
-
obs_grp.create_dataset("gripper_states", data=np.stack(gripper_states, axis=0))
|
| 189 |
-
obs_grp.create_dataset("joint_states", data=np.stack(joint_states, axis=0))
|
| 190 |
-
obs_grp.create_dataset("ee_states", data=np.stack(ee_states, axis=0))
|
| 191 |
-
obs_grp.create_dataset("ee_pos", data=np.stack(ee_states, axis=0)[:, :3])
|
| 192 |
-
obs_grp.create_dataset("ee_ori", data=np.stack(ee_states, axis=0)[:, 3:])
|
| 193 |
-
obs_grp.create_dataset("agentview_rgb", data=np.stack(agentview_images, axis=0))
|
| 194 |
-
obs_grp.create_dataset("eye_in_hand_rgb", data=np.stack(eye_in_hand_images, axis=0))
|
| 195 |
-
ep_data_grp.create_dataset("actions", data=actions)
|
| 196 |
-
ep_data_grp.create_dataset("states", data=np.stack(states))
|
| 197 |
-
ep_data_grp.create_dataset("robot_states", data=np.stack(robot_states, axis=0))
|
| 198 |
-
ep_data_grp.create_dataset("rewards", data=rewards)
|
| 199 |
-
ep_data_grp.create_dataset("dones", data=dones)
|
| 200 |
-
|
| 201 |
-
num_success += 1
|
| 202 |
-
|
| 203 |
-
num_replays += 1
|
| 204 |
-
|
| 205 |
-
# Record success/false and initial environment state in metainfo dict
|
| 206 |
-
task_key = task_description.replace(" ", "_")
|
| 207 |
-
episode_key = f"demo_{i}"
|
| 208 |
-
if task_key not in metainfo_json_dict:
|
| 209 |
-
metainfo_json_dict[task_key] = {}
|
| 210 |
-
if episode_key not in metainfo_json_dict[task_key]:
|
| 211 |
-
metainfo_json_dict[task_key][episode_key] = {}
|
| 212 |
-
metainfo_json_dict[task_key][episode_key]["success"] = bool(done)
|
| 213 |
-
metainfo_json_dict[task_key][episode_key]["initial_state"] = orig_states[0].tolist()
|
| 214 |
-
|
| 215 |
-
# Write metainfo dict to JSON file
|
| 216 |
-
# (We repeatedly overwrite, rather than doing this once at the end, just in case the script crashes midway)
|
| 217 |
-
with open(metainfo_json_out_path, "w") as f:
|
| 218 |
-
json.dump(metainfo_json_dict, f, indent=2)
|
| 219 |
-
|
| 220 |
-
# Count total number of successful replays so far
|
| 221 |
-
print(
|
| 222 |
-
f"Total # episodes replayed: {num_replays}, Total # successes: {num_success} ({num_success / num_replays * 100:.1f} %)"
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
# Report total number of no-op actions filtered out so far
|
| 226 |
-
print(f" Total # no-op actions filtered out: {num_noops}")
|
| 227 |
-
|
| 228 |
-
# Close HDF5 files
|
| 229 |
-
orig_data_file.close()
|
| 230 |
-
new_data_file.close()
|
| 231 |
-
print(f"Saved regenerated demos for task '{task_description}' at: {new_data_path}")
|
| 232 |
-
|
| 233 |
-
print(f"Dataset regeneration complete! Saved new dataset at: {args.libero_target_dir}")
|
| 234 |
-
print(f"Saved metainfo JSON at: {metainfo_json_out_path}")
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
if __name__ == "__main__":
|
| 238 |
-
# Parse command-line arguments
|
| 239 |
-
parser = argparse.ArgumentParser()
|
| 240 |
-
parser.add_argument("--libero_task_suite", type=str, choices=["libero_spatial", "libero_object", "libero_goal", "libero_10", "libero_90"],
|
| 241 |
-
help="LIBERO task suite. Example: libero_spatial", required=True)
|
| 242 |
-
parser.add_argument("--libero_raw_data_dir", type=str,
|
| 243 |
-
help="Path to directory containing raw HDF5 dataset. Example: ./LIBERO/libero/datasets/libero_spatial", required=True)
|
| 244 |
-
parser.add_argument("--libero_target_dir", type=str,
|
| 245 |
-
help="Path to regenerated dataset directory. Example: ./LIBERO/libero/datasets/libero_spatial_no_noops", required=True)
|
| 246 |
-
args = parser.parse_args()
|
| 247 |
-
|
| 248 |
-
# Start data regeneration
|
| 249 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/libero/run_libero_eval.py
DELETED
|
@@ -1,540 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
run_libero_eval.py
|
| 3 |
-
|
| 4 |
-
Evaluates a trained policy in a LIBERO simulation benchmark task suite.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import json
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
import sys
|
| 11 |
-
from collections import deque
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from enum import Enum
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Optional, Union
|
| 16 |
-
|
| 17 |
-
import draccus
|
| 18 |
-
import numpy as np
|
| 19 |
-
import tqdm
|
| 20 |
-
from libero.libero import benchmark
|
| 21 |
-
|
| 22 |
-
import wandb
|
| 23 |
-
|
| 24 |
-
# Append current directory so that interpreter can find experiments.robot
|
| 25 |
-
sys.path.append("../..")
|
| 26 |
-
from experiments.robot.libero.libero_utils import (
|
| 27 |
-
get_libero_dummy_action,
|
| 28 |
-
get_libero_env,
|
| 29 |
-
get_libero_image,
|
| 30 |
-
get_libero_wrist_image,
|
| 31 |
-
quat2axisangle,
|
| 32 |
-
save_rollout_video,
|
| 33 |
-
)
|
| 34 |
-
from experiments.robot.openvla_utils import (
|
| 35 |
-
get_action_head,
|
| 36 |
-
get_noisy_action_projector,
|
| 37 |
-
get_processor,
|
| 38 |
-
get_proprio_projector,
|
| 39 |
-
resize_image_for_policy,
|
| 40 |
-
)
|
| 41 |
-
from experiments.robot.robot_utils import (
|
| 42 |
-
DATE_TIME,
|
| 43 |
-
get_action,
|
| 44 |
-
get_image_resize_size,
|
| 45 |
-
get_model,
|
| 46 |
-
invert_gripper_action,
|
| 47 |
-
normalize_gripper_action,
|
| 48 |
-
set_seed_everywhere,
|
| 49 |
-
)
|
| 50 |
-
from prismatic.vla.constants import NUM_ACTIONS_CHUNK
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# import debugpy
|
| 54 |
-
# try:
|
| 55 |
-
# debugpy.listen(("localhost", 9501))
|
| 56 |
-
# print("Waiting for debugger attach")
|
| 57 |
-
# debugpy.wait_for_client()
|
| 58 |
-
# except Exception as e:
|
| 59 |
-
# pass
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# Define task suite constants
|
| 63 |
-
class TaskSuite(str, Enum):
|
| 64 |
-
LIBERO_SPATIAL = "libero_spatial"
|
| 65 |
-
LIBERO_OBJECT = "libero_object"
|
| 66 |
-
LIBERO_GOAL = "libero_goal"
|
| 67 |
-
LIBERO_10 = "libero_10"
|
| 68 |
-
LIBERO_90 = "libero_90"
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# Define max steps for each task suite
|
| 72 |
-
TASK_MAX_STEPS = {
|
| 73 |
-
TaskSuite.LIBERO_SPATIAL: 220, # longest training demo has 193 steps
|
| 74 |
-
TaskSuite.LIBERO_OBJECT: 280, # longest training demo has 254 steps
|
| 75 |
-
TaskSuite.LIBERO_GOAL: 300, # longest training demo has 270 steps
|
| 76 |
-
TaskSuite.LIBERO_10: 520, # longest training demo has 505 steps
|
| 77 |
-
TaskSuite.LIBERO_90: 400, # longest training demo has 373 steps
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
# Set up logging
|
| 82 |
-
logging.basicConfig(
|
| 83 |
-
level=logging.INFO,
|
| 84 |
-
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 85 |
-
handlers=[logging.StreamHandler()],
|
| 86 |
-
)
|
| 87 |
-
logger = logging.getLogger(__name__)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
@dataclass
|
| 91 |
-
class GenerateConfig:
|
| 92 |
-
# fmt: off
|
| 93 |
-
|
| 94 |
-
#################################################################################################################
|
| 95 |
-
# Model-specific parameters
|
| 96 |
-
#################################################################################################################
|
| 97 |
-
model_family: str = "openvla" # Model family
|
| 98 |
-
pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path
|
| 99 |
-
|
| 100 |
-
use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective
|
| 101 |
-
use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM)
|
| 102 |
-
num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
|
| 103 |
-
num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference
|
| 104 |
-
use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
|
| 105 |
-
num_images_in_input: int = 2 # Number of images in the VLA input (default: 1)
|
| 106 |
-
use_proprio: bool = True # Whether to include proprio state in input
|
| 107 |
-
|
| 108 |
-
center_crop: bool = True # Center crop? (if trained w/ random crop image aug)
|
| 109 |
-
num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy
|
| 110 |
-
|
| 111 |
-
lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
|
| 112 |
-
|
| 113 |
-
unnorm_key: Union[str, Path] = "" # Action un-normalization key
|
| 114 |
-
|
| 115 |
-
load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization
|
| 116 |
-
load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization
|
| 117 |
-
|
| 118 |
-
#################################################################################################################
|
| 119 |
-
# LIBERO environment-specific parameters
|
| 120 |
-
#################################################################################################################
|
| 121 |
-
task_suite_name: str = TaskSuite.LIBERO_SPATIAL # Task suite
|
| 122 |
-
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
|
| 123 |
-
num_trials_per_task: int = 50 # Number of rollouts per task
|
| 124 |
-
initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file
|
| 125 |
-
env_img_res: int = 256 # Resolution for environment images (not policy input resolution)
|
| 126 |
-
|
| 127 |
-
#################################################################################################################
|
| 128 |
-
# Utils
|
| 129 |
-
#################################################################################################################
|
| 130 |
-
run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
|
| 131 |
-
local_log_dir: str = "./experiments/logs" # Local directory for eval logs
|
| 132 |
-
|
| 133 |
-
use_wandb: bool = False # Whether to also log results in Weights & Biases
|
| 134 |
-
wandb_entity: str = "your-wandb-entity" # Name of WandB entity
|
| 135 |
-
wandb_project: str = "your-wandb-project" # Name of WandB project
|
| 136 |
-
|
| 137 |
-
seed: int = 7 # Random Seed (for reproducibility)
|
| 138 |
-
|
| 139 |
-
# fmt: on
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def validate_config(cfg: GenerateConfig) -> None:
|
| 143 |
-
"""Validate configuration parameters."""
|
| 144 |
-
assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!"
|
| 145 |
-
|
| 146 |
-
if "image_aug" in str(cfg.pretrained_checkpoint):
|
| 147 |
-
assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
|
| 148 |
-
|
| 149 |
-
assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"
|
| 150 |
-
|
| 151 |
-
# Validate task suite
|
| 152 |
-
assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}"
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def initialize_model(cfg: GenerateConfig):
|
| 156 |
-
"""Initialize model and associated components."""
|
| 157 |
-
# Load model
|
| 158 |
-
model = get_model(cfg)
|
| 159 |
-
|
| 160 |
-
# Load proprio projector if needed
|
| 161 |
-
proprio_projector = None
|
| 162 |
-
if cfg.use_proprio:
|
| 163 |
-
proprio_projector = get_proprio_projector(
|
| 164 |
-
cfg,
|
| 165 |
-
model.llm_dim,
|
| 166 |
-
proprio_dim=8, # 8-dimensional proprio for LIBERO
|
| 167 |
-
)
|
| 168 |
-
|
| 169 |
-
# Load action head if needed
|
| 170 |
-
action_head = None
|
| 171 |
-
if cfg.use_l1_regression or cfg.use_diffusion:
|
| 172 |
-
action_head = get_action_head(cfg, model.llm_dim)
|
| 173 |
-
|
| 174 |
-
# Load noisy action projector if using diffusion
|
| 175 |
-
noisy_action_projector = None
|
| 176 |
-
if cfg.use_diffusion:
|
| 177 |
-
noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)
|
| 178 |
-
|
| 179 |
-
# Get OpenVLA processor if needed
|
| 180 |
-
processor = None
|
| 181 |
-
if cfg.model_family == "openvla":
|
| 182 |
-
processor = get_processor(cfg)
|
| 183 |
-
check_unnorm_key(cfg, model)
|
| 184 |
-
|
| 185 |
-
return model, action_head, proprio_projector, noisy_action_projector, processor
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def check_unnorm_key(cfg: GenerateConfig, model) -> None:
|
| 189 |
-
"""Check that the model contains the action un-normalization key."""
|
| 190 |
-
# Initialize unnorm_key
|
| 191 |
-
unnorm_key = cfg.task_suite_name
|
| 192 |
-
|
| 193 |
-
# In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
|
| 194 |
-
# with the suffix "_no_noops" in the dataset name)
|
| 195 |
-
if unnorm_key not in model.norm_stats and f"{unnorm_key}_no_noops" in model.norm_stats:
|
| 196 |
-
unnorm_key = f"{unnorm_key}_no_noops"
|
| 197 |
-
|
| 198 |
-
assert unnorm_key in model.norm_stats, f"Action un-norm key {unnorm_key} not found in VLA `norm_stats`!"
|
| 199 |
-
|
| 200 |
-
# Set the unnorm_key in cfg
|
| 201 |
-
cfg.unnorm_key = unnorm_key
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def setup_logging(cfg: GenerateConfig):
|
| 205 |
-
"""Set up logging to file and optionally to wandb."""
|
| 206 |
-
# Create run ID
|
| 207 |
-
run_id = f"EVAL-{cfg.task_suite_name}-{cfg.model_family}-{DATE_TIME}"
|
| 208 |
-
if cfg.run_id_note is not None:
|
| 209 |
-
run_id += f"--{cfg.run_id_note}"
|
| 210 |
-
|
| 211 |
-
# Set up local logging
|
| 212 |
-
os.makedirs(cfg.local_log_dir, exist_ok=True)
|
| 213 |
-
local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
|
| 214 |
-
log_file = open(local_log_filepath, "w")
|
| 215 |
-
logger.info(f"Logging to local log file: {local_log_filepath}")
|
| 216 |
-
|
| 217 |
-
# Initialize Weights & Biases logging if enabled
|
| 218 |
-
if cfg.use_wandb:
|
| 219 |
-
wandb.init(
|
| 220 |
-
entity=cfg.wandb_entity,
|
| 221 |
-
project=cfg.wandb_project,
|
| 222 |
-
name=run_id,
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
return log_file, local_log_filepath, run_id
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def log_message(message: str, log_file=None):
|
| 229 |
-
"""Log a message to console and optionally to a log file."""
|
| 230 |
-
logger.info(message)
|
| 231 |
-
if log_file:
|
| 232 |
-
log_file.write(message + "\n")
|
| 233 |
-
log_file.flush()
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
def load_initial_states(cfg: GenerateConfig, task_suite, task_id: int, log_file=None):
|
| 237 |
-
"""Load initial states for the given task."""
|
| 238 |
-
# Get default initial states
|
| 239 |
-
initial_states = task_suite.get_task_init_states(task_id)
|
| 240 |
-
|
| 241 |
-
# If using custom initial states, load them from file
|
| 242 |
-
if cfg.initial_states_path != "DEFAULT":
|
| 243 |
-
with open(cfg.initial_states_path, "r") as f:
|
| 244 |
-
all_initial_states = json.load(f)
|
| 245 |
-
log_message(f"Using initial states from {cfg.initial_states_path}", log_file)
|
| 246 |
-
return initial_states, all_initial_states
|
| 247 |
-
else:
|
| 248 |
-
log_message("Using default initial states", log_file)
|
| 249 |
-
return initial_states, None
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def prepare_observation(obs, resize_size):
|
| 253 |
-
"""Prepare observation for policy input."""
|
| 254 |
-
# Get preprocessed images
|
| 255 |
-
img = get_libero_image(obs)
|
| 256 |
-
wrist_img = get_libero_wrist_image(obs)
|
| 257 |
-
|
| 258 |
-
# Resize images to size expected by model
|
| 259 |
-
img_resized = resize_image_for_policy(img, resize_size)
|
| 260 |
-
wrist_img_resized = resize_image_for_policy(wrist_img, resize_size)
|
| 261 |
-
|
| 262 |
-
# Prepare observations dict
|
| 263 |
-
observation = {
|
| 264 |
-
"full_image": img_resized,
|
| 265 |
-
"wrist_image": wrist_img_resized,
|
| 266 |
-
"state": np.concatenate(
|
| 267 |
-
(obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"])
|
| 268 |
-
),
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
-
return observation, img # Return both processed observation and original image for replay
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
def process_action(action, model_family):
|
| 275 |
-
"""Process action before sending to environment."""
|
| 276 |
-
# Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter
|
| 277 |
-
action = normalize_gripper_action(action, binarize=True)
|
| 278 |
-
|
| 279 |
-
# [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets
|
| 280 |
-
# (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action
|
| 281 |
-
if model_family == "openvla":
|
| 282 |
-
action = invert_gripper_action(action)
|
| 283 |
-
|
| 284 |
-
return action
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def run_episode(
|
| 288 |
-
cfg: GenerateConfig,
|
| 289 |
-
env,
|
| 290 |
-
task_description: str,
|
| 291 |
-
model,
|
| 292 |
-
resize_size,
|
| 293 |
-
processor=None,
|
| 294 |
-
action_head=None,
|
| 295 |
-
proprio_projector=None,
|
| 296 |
-
noisy_action_projector=None,
|
| 297 |
-
initial_state=None,
|
| 298 |
-
log_file=None,
|
| 299 |
-
):
|
| 300 |
-
"""Run a single episode in the environment."""
|
| 301 |
-
# Reset environment
|
| 302 |
-
env.reset()
|
| 303 |
-
|
| 304 |
-
# Set initial state if provided
|
| 305 |
-
if initial_state is not None:
|
| 306 |
-
obs = env.set_init_state(initial_state)
|
| 307 |
-
else:
|
| 308 |
-
obs = env.get_observation()
|
| 309 |
-
|
| 310 |
-
# Initialize action queue
|
| 311 |
-
if cfg.num_open_loop_steps != NUM_ACTIONS_CHUNK:
|
| 312 |
-
print(f"WARNING: cfg.num_open_loop_steps ({cfg.num_open_loop_steps}) does not match the NUM_ACTIONS_CHUNK "
|
| 313 |
-
f"({NUM_ACTIONS_CHUNK}) constant defined in prismatic.vla.constants! For best performance (in terms of "
|
| 314 |
-
"both speed and success rate), we recommend executing the full action chunk.")
|
| 315 |
-
action_queue = deque(maxlen=cfg.num_open_loop_steps)
|
| 316 |
-
|
| 317 |
-
# Setup
|
| 318 |
-
t = 0
|
| 319 |
-
replay_images = []
|
| 320 |
-
max_steps = TASK_MAX_STEPS[cfg.task_suite_name]
|
| 321 |
-
|
| 322 |
-
# Run episode
|
| 323 |
-
success = False
|
| 324 |
-
try:
|
| 325 |
-
while t < max_steps + cfg.num_steps_wait:
|
| 326 |
-
# Do nothing for the first few timesteps to let objects stabilize
|
| 327 |
-
if t < cfg.num_steps_wait:
|
| 328 |
-
obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family))
|
| 329 |
-
t += 1
|
| 330 |
-
continue
|
| 331 |
-
|
| 332 |
-
# Prepare observation
|
| 333 |
-
observation, img = prepare_observation(obs, resize_size)
|
| 334 |
-
replay_images.append(img)
|
| 335 |
-
|
| 336 |
-
# If action queue is empty, requery model
|
| 337 |
-
if len(action_queue) == 0:
|
| 338 |
-
# Query model to get action
|
| 339 |
-
actions = get_action(
|
| 340 |
-
cfg,
|
| 341 |
-
model,
|
| 342 |
-
observation,
|
| 343 |
-
task_description,
|
| 344 |
-
processor=processor,
|
| 345 |
-
action_head=action_head,
|
| 346 |
-
proprio_projector=proprio_projector,
|
| 347 |
-
noisy_action_projector=noisy_action_projector,
|
| 348 |
-
use_film=cfg.use_film,
|
| 349 |
-
)
|
| 350 |
-
action_queue.extend(actions)
|
| 351 |
-
|
| 352 |
-
# Get action from queue
|
| 353 |
-
action = action_queue.popleft()
|
| 354 |
-
|
| 355 |
-
# Process action
|
| 356 |
-
action = process_action(action, cfg.model_family)
|
| 357 |
-
|
| 358 |
-
# Execute action in environment
|
| 359 |
-
obs, reward, done, info = env.step(action.tolist())
|
| 360 |
-
if done:
|
| 361 |
-
success = True
|
| 362 |
-
break
|
| 363 |
-
t += 1
|
| 364 |
-
|
| 365 |
-
except Exception as e:
|
| 366 |
-
log_message(f"Episode error: {e}", log_file)
|
| 367 |
-
|
| 368 |
-
return success, replay_images
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
def run_task(
|
| 372 |
-
cfg: GenerateConfig,
|
| 373 |
-
task_suite,
|
| 374 |
-
task_id: int,
|
| 375 |
-
model,
|
| 376 |
-
resize_size,
|
| 377 |
-
processor=None,
|
| 378 |
-
action_head=None,
|
| 379 |
-
proprio_projector=None,
|
| 380 |
-
noisy_action_projector=None,
|
| 381 |
-
total_episodes=0,
|
| 382 |
-
total_successes=0,
|
| 383 |
-
log_file=None,
|
| 384 |
-
):
|
| 385 |
-
"""Run evaluation for a single task."""
|
| 386 |
-
# Get task
|
| 387 |
-
task = task_suite.get_task(task_id)
|
| 388 |
-
|
| 389 |
-
# Get initial states
|
| 390 |
-
initial_states, all_initial_states = load_initial_states(cfg, task_suite, task_id, log_file)
|
| 391 |
-
|
| 392 |
-
# Initialize environment and get task description
|
| 393 |
-
env, task_description = get_libero_env(task, cfg.model_family, resolution=cfg.env_img_res)
|
| 394 |
-
|
| 395 |
-
# Start episodes
|
| 396 |
-
task_episodes, task_successes = 0, 0
|
| 397 |
-
for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
|
| 398 |
-
log_message(f"\nTask: {task_description}", log_file)
|
| 399 |
-
|
| 400 |
-
# Handle initial state
|
| 401 |
-
if cfg.initial_states_path == "DEFAULT":
|
| 402 |
-
# Use default initial state
|
| 403 |
-
initial_state = initial_states[episode_idx]
|
| 404 |
-
else:
|
| 405 |
-
# Get keys for fetching initial episode state from JSON
|
| 406 |
-
initial_states_task_key = task_description.replace(" ", "_")
|
| 407 |
-
episode_key = f"demo_{episode_idx}"
|
| 408 |
-
|
| 409 |
-
# Skip episode if expert demonstration failed to complete the task
|
| 410 |
-
if not all_initial_states[initial_states_task_key][episode_key]["success"]:
|
| 411 |
-
log_message(f"Skipping task {task_id} episode {episode_idx} due to failed expert demo!", log_file)
|
| 412 |
-
continue
|
| 413 |
-
|
| 414 |
-
# Get initial state
|
| 415 |
-
initial_state = np.array(all_initial_states[initial_states_task_key][episode_key]["initial_state"])
|
| 416 |
-
|
| 417 |
-
log_message(f"Starting episode {task_episodes + 1}...", log_file)
|
| 418 |
-
|
| 419 |
-
# Run episode
|
| 420 |
-
success, replay_images = run_episode(
|
| 421 |
-
cfg,
|
| 422 |
-
env,
|
| 423 |
-
task_description,
|
| 424 |
-
model,
|
| 425 |
-
resize_size,
|
| 426 |
-
processor,
|
| 427 |
-
action_head,
|
| 428 |
-
proprio_projector,
|
| 429 |
-
noisy_action_projector,
|
| 430 |
-
initial_state,
|
| 431 |
-
log_file,
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
# Update counters
|
| 435 |
-
task_episodes += 1
|
| 436 |
-
total_episodes += 1
|
| 437 |
-
if success:
|
| 438 |
-
task_successes += 1
|
| 439 |
-
total_successes += 1
|
| 440 |
-
|
| 441 |
-
# Save replay video
|
| 442 |
-
save_rollout_video(
|
| 443 |
-
replay_images, total_episodes, success=success, task_description=task_description, log_file=log_file
|
| 444 |
-
)
|
| 445 |
-
|
| 446 |
-
# Log results
|
| 447 |
-
log_message(f"Success: {success}", log_file)
|
| 448 |
-
log_message(f"# episodes completed so far: {total_episodes}", log_file)
|
| 449 |
-
log_message(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)", log_file)
|
| 450 |
-
|
| 451 |
-
# Log task results
|
| 452 |
-
task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0
|
| 453 |
-
total_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0
|
| 454 |
-
|
| 455 |
-
log_message(f"Current task success rate: {task_success_rate}", log_file)
|
| 456 |
-
log_message(f"Current total success rate: {total_success_rate}", log_file)
|
| 457 |
-
|
| 458 |
-
# Log to wandb if enabled
|
| 459 |
-
if cfg.use_wandb:
|
| 460 |
-
wandb.log(
|
| 461 |
-
{
|
| 462 |
-
f"success_rate/{task_description}": task_success_rate,
|
| 463 |
-
f"num_episodes/{task_description}": task_episodes,
|
| 464 |
-
}
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
return total_episodes, total_successes
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
@draccus.wrap()
|
| 471 |
-
def eval_libero(cfg: GenerateConfig) -> float:
|
| 472 |
-
"""Main function to evaluate a trained policy on LIBERO benchmark tasks."""
|
| 473 |
-
# Validate configuration
|
| 474 |
-
validate_config(cfg)
|
| 475 |
-
|
| 476 |
-
# Set random seed
|
| 477 |
-
set_seed_everywhere(cfg.seed)
|
| 478 |
-
|
| 479 |
-
# Initialize model and components
|
| 480 |
-
model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
|
| 481 |
-
|
| 482 |
-
# Get expected image dimensions
|
| 483 |
-
resize_size = get_image_resize_size(cfg)
|
| 484 |
-
|
| 485 |
-
# Setup logging
|
| 486 |
-
log_file, local_log_filepath, run_id = setup_logging(cfg)
|
| 487 |
-
|
| 488 |
-
# Initialize LIBERO task suite
|
| 489 |
-
benchmark_dict = benchmark.get_benchmark_dict()
|
| 490 |
-
task_suite = benchmark_dict[cfg.task_suite_name]()
|
| 491 |
-
num_tasks = task_suite.n_tasks
|
| 492 |
-
|
| 493 |
-
log_message(f"Task suite: {cfg.task_suite_name}", log_file)
|
| 494 |
-
|
| 495 |
-
# Start evaluation
|
| 496 |
-
total_episodes, total_successes = 0, 0
|
| 497 |
-
for task_id in tqdm.tqdm(range(num_tasks)):
|
| 498 |
-
total_episodes, total_successes = run_task(
|
| 499 |
-
cfg,
|
| 500 |
-
task_suite,
|
| 501 |
-
task_id,
|
| 502 |
-
model,
|
| 503 |
-
resize_size,
|
| 504 |
-
processor,
|
| 505 |
-
action_head,
|
| 506 |
-
proprio_projector,
|
| 507 |
-
noisy_action_projector,
|
| 508 |
-
total_episodes,
|
| 509 |
-
total_successes,
|
| 510 |
-
log_file,
|
| 511 |
-
)
|
| 512 |
-
|
| 513 |
-
# Calculate final success rate
|
| 514 |
-
final_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0
|
| 515 |
-
|
| 516 |
-
# Log final results
|
| 517 |
-
log_message("Final results:", log_file)
|
| 518 |
-
log_message(f"Total episodes: {total_episodes}", log_file)
|
| 519 |
-
log_message(f"Total successes: {total_successes}", log_file)
|
| 520 |
-
log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file)
|
| 521 |
-
|
| 522 |
-
# Log to wandb if enabled
|
| 523 |
-
if cfg.use_wandb:
|
| 524 |
-
wandb.log(
|
| 525 |
-
{
|
| 526 |
-
"success_rate/total": final_success_rate,
|
| 527 |
-
"num_episodes/total": total_episodes,
|
| 528 |
-
}
|
| 529 |
-
)
|
| 530 |
-
wandb.save(local_log_filepath)
|
| 531 |
-
|
| 532 |
-
# Close log file
|
| 533 |
-
if log_file:
|
| 534 |
-
log_file.close()
|
| 535 |
-
|
| 536 |
-
return final_success_rate
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
if __name__ == "__main__":
|
| 540 |
-
eval_libero()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:326db6c78dd0a9d91c11f05af03b93fa3095338ee3cb5a5eb15adf3d87eb0109
|
| 3 |
-
size 301501
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/openvla_utils.py
DELETED
|
@@ -1,818 +0,0 @@
|
|
| 1 |
-
"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies."""
|
| 2 |
-
|
| 3 |
-
import filecmp
|
| 4 |
-
import json
|
| 5 |
-
import os
|
| 6 |
-
import shutil
|
| 7 |
-
import time
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 11 |
-
|
| 12 |
-
import json_numpy
|
| 13 |
-
import numpy as np
|
| 14 |
-
import requests
|
| 15 |
-
import tensorflow as tf
|
| 16 |
-
import torch
|
| 17 |
-
from huggingface_hub import HfApi, hf_hub_download
|
| 18 |
-
from PIL import Image
|
| 19 |
-
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
|
| 20 |
-
|
| 21 |
-
# Apply JSON numpy patch for serialization
|
| 22 |
-
json_numpy.patch()
|
| 23 |
-
|
| 24 |
-
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 25 |
-
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 26 |
-
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 27 |
-
from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
|
| 28 |
-
from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
|
| 29 |
-
from prismatic.models.projectors import NoisyActionProjector, ProprioProjector
|
| 30 |
-
from prismatic.vla.constants import (
|
| 31 |
-
ACTION_DIM,
|
| 32 |
-
ACTION_PROPRIO_NORMALIZATION_TYPE,
|
| 33 |
-
)
|
| 34 |
-
from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType
|
| 35 |
-
|
| 36 |
-
# Initialize important constants
|
| 37 |
-
DATE = time.strftime("%Y_%m_%d")
|
| 38 |
-
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
|
| 39 |
-
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
| 40 |
-
OPENVLA_IMAGE_SIZE = 224 # Standard image size expected by OpenVLA
|
| 41 |
-
|
| 42 |
-
# Configure NumPy print settings
|
| 43 |
-
np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def model_is_on_hf_hub(model_path: str) -> bool:
|
| 47 |
-
"""Checks whether a model path points to a model on Hugging Face Hub."""
|
| 48 |
-
# If the API call below runs without error, the model is on the hub
|
| 49 |
-
try:
|
| 50 |
-
HfApi().model_info(model_path)
|
| 51 |
-
return True
|
| 52 |
-
except Exception:
|
| 53 |
-
return False
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def update_auto_map(pretrained_checkpoint: str) -> None:
|
| 57 |
-
"""
|
| 58 |
-
Update the AutoMap configuration in the checkpoint config.json file.
|
| 59 |
-
|
| 60 |
-
This loads the config.json file inside the checkpoint directory and overwrites
|
| 61 |
-
the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
pretrained_checkpoint: Path to the checkpoint directory
|
| 65 |
-
"""
|
| 66 |
-
if not os.path.isdir(pretrained_checkpoint):
|
| 67 |
-
return
|
| 68 |
-
|
| 69 |
-
config_path = os.path.join(pretrained_checkpoint, "config.json")
|
| 70 |
-
if not os.path.exists(config_path):
|
| 71 |
-
print(f"Warning: No config.json found at {config_path}")
|
| 72 |
-
return
|
| 73 |
-
|
| 74 |
-
# Create timestamped backup
|
| 75 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 76 |
-
backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}")
|
| 77 |
-
shutil.copy2(config_path, backup_path)
|
| 78 |
-
print(f"Created backup of original config at: {os.path.abspath(backup_path)}")
|
| 79 |
-
|
| 80 |
-
# Read and update the config
|
| 81 |
-
with open(config_path, "r") as f:
|
| 82 |
-
config = json.load(f)
|
| 83 |
-
|
| 84 |
-
config["auto_map"] = {
|
| 85 |
-
"AutoConfig": "configuration_prismatic.OpenVLAConfig",
|
| 86 |
-
"AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction",
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
# Write back the updated config
|
| 90 |
-
with open(config_path, "w") as f:
|
| 91 |
-
json.dump(config, f, indent=2)
|
| 92 |
-
|
| 93 |
-
print(f"Updated config.json at: {os.path.abspath(config_path)}")
|
| 94 |
-
print("Changes made:")
|
| 95 |
-
print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"')
|
| 96 |
-
print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"')
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:
|
| 100 |
-
"""
|
| 101 |
-
Check if two files are identical in content.
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
path1: Path to the first file
|
| 105 |
-
path2: Path to the second file
|
| 106 |
-
|
| 107 |
-
Returns:
|
| 108 |
-
bool: True if files are identical, False otherwise
|
| 109 |
-
"""
|
| 110 |
-
path1, path2 = Path(path1), Path(path2)
|
| 111 |
-
|
| 112 |
-
# First check if file sizes match
|
| 113 |
-
if path1.stat().st_size != path2.stat().st_size:
|
| 114 |
-
return False
|
| 115 |
-
|
| 116 |
-
# Check if contents match
|
| 117 |
-
return filecmp.cmp(path1, path2, shallow=False)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None:
|
| 121 |
-
"""
|
| 122 |
-
Handle syncing of files between current directory and checkpoint.
|
| 123 |
-
|
| 124 |
-
Creates backups if files exist but differ, and copies current versions to checkpoint.
|
| 125 |
-
|
| 126 |
-
Args:
|
| 127 |
-
curr_filepath: Path to the current file version
|
| 128 |
-
checkpoint_filepath: Path where the file should be in the checkpoint
|
| 129 |
-
file_type: Description of the file type for logging
|
| 130 |
-
"""
|
| 131 |
-
if os.path.exists(checkpoint_filepath):
|
| 132 |
-
# Check if existing files are identical
|
| 133 |
-
match = check_identical_files(curr_filepath, checkpoint_filepath)
|
| 134 |
-
|
| 135 |
-
if not match:
|
| 136 |
-
print(
|
| 137 |
-
"\n------------------------------------------------------------------------------------------------\n"
|
| 138 |
-
f"Found mismatch between:\n"
|
| 139 |
-
f"Current: {curr_filepath}\n"
|
| 140 |
-
f"Checkpoint: {checkpoint_filepath}\n"
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
# Create timestamped backup
|
| 144 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 145 |
-
backup_path = f"{checkpoint_filepath}.back.{timestamp}"
|
| 146 |
-
shutil.copy2(checkpoint_filepath, backup_path)
|
| 147 |
-
print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}")
|
| 148 |
-
|
| 149 |
-
# Copy current version to checkpoint directory
|
| 150 |
-
shutil.copy2(curr_filepath, checkpoint_filepath)
|
| 151 |
-
print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}")
|
| 152 |
-
print(
|
| 153 |
-
f"Changes complete. The checkpoint will now use the current version of {file_type}"
|
| 154 |
-
"\n------------------------------------------------------------------------------------------------\n"
|
| 155 |
-
)
|
| 156 |
-
else:
|
| 157 |
-
# If file doesn't exist in checkpoint directory, copy it
|
| 158 |
-
shutil.copy2(curr_filepath, checkpoint_filepath)
|
| 159 |
-
print(
|
| 160 |
-
"\n------------------------------------------------------------------------------------------------\n"
|
| 161 |
-
f"No {file_type} found in checkpoint directory.\n"
|
| 162 |
-
f"Copied current version from: {curr_filepath}\n"
|
| 163 |
-
f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}"
|
| 164 |
-
"\n------------------------------------------------------------------------------------------------\n"
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
def check_model_logic_mismatch(pretrained_checkpoint: str) -> None:
|
| 169 |
-
"""
|
| 170 |
-
Check and sync model logic files between current code and checkpoint.
|
| 171 |
-
|
| 172 |
-
Handles the relationship between current and checkpoint versions of both
|
| 173 |
-
modeling_prismatic.py and configuration_prismatic.py:
|
| 174 |
-
- If checkpoint file exists and differs: creates backup and copies current version
|
| 175 |
-
- If checkpoint file doesn't exist: copies current version
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
pretrained_checkpoint: Path to the checkpoint directory
|
| 179 |
-
"""
|
| 180 |
-
if not os.path.isdir(pretrained_checkpoint):
|
| 181 |
-
return
|
| 182 |
-
|
| 183 |
-
# Find current files
|
| 184 |
-
curr_files = {"modeling_prismatic.py": None, "configuration_prismatic.py": None}
|
| 185 |
-
|
| 186 |
-
for root, _, files in os.walk("./prismatic/"):
|
| 187 |
-
for filename in curr_files.keys():
|
| 188 |
-
if filename in files and curr_files[filename] is None:
|
| 189 |
-
curr_files[filename] = os.path.join(root, filename)
|
| 190 |
-
|
| 191 |
-
# Check and handle each file
|
| 192 |
-
for filename, curr_filepath in curr_files.items():
|
| 193 |
-
if curr_filepath is None:
|
| 194 |
-
print(f"WARNING: `{filename}` is not found anywhere in the current directory.")
|
| 195 |
-
continue
|
| 196 |
-
|
| 197 |
-
checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)
|
| 198 |
-
_handle_file_sync(curr_filepath, checkpoint_filepath, filename)
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str:
|
| 202 |
-
"""
|
| 203 |
-
Find a specific checkpoint file matching a pattern.
|
| 204 |
-
|
| 205 |
-
Args:
|
| 206 |
-
pretrained_checkpoint: Path to the checkpoint directory
|
| 207 |
-
file_pattern: String pattern to match in filenames
|
| 208 |
-
|
| 209 |
-
Returns:
|
| 210 |
-
str: Path to the matching checkpoint file
|
| 211 |
-
|
| 212 |
-
Raises:
|
| 213 |
-
AssertionError: If no files or multiple files match the pattern
|
| 214 |
-
"""
|
| 215 |
-
assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}"
|
| 216 |
-
|
| 217 |
-
checkpoint_files = []
|
| 218 |
-
for filename in os.listdir(pretrained_checkpoint):
|
| 219 |
-
if file_pattern in filename and "checkpoint" in filename:
|
| 220 |
-
full_path = os.path.join(pretrained_checkpoint, filename)
|
| 221 |
-
checkpoint_files.append(full_path)
|
| 222 |
-
|
| 223 |
-
assert len(checkpoint_files) == 1, (
|
| 224 |
-
f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}"
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
return checkpoint_files[0]
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]:
|
| 231 |
-
"""
|
| 232 |
-
Load a component's state dict from checkpoint and handle DDP prefix if present.
|
| 233 |
-
|
| 234 |
-
Args:
|
| 235 |
-
checkpoint_path: Path to the checkpoint file
|
| 236 |
-
|
| 237 |
-
Returns:
|
| 238 |
-
Dict: The processed state dictionary for loading
|
| 239 |
-
"""
|
| 240 |
-
state_dict = torch.load(checkpoint_path, weights_only=True)
|
| 241 |
-
|
| 242 |
-
# If the component was trained with DDP, elements in the state dict have prefix "module." which we must remove
|
| 243 |
-
new_state_dict = {}
|
| 244 |
-
for k, v in state_dict.items():
|
| 245 |
-
if k.startswith("module."):
|
| 246 |
-
new_state_dict[k[7:]] = v
|
| 247 |
-
else:
|
| 248 |
-
new_state_dict[k] = v
|
| 249 |
-
|
| 250 |
-
return new_state_dict
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
def get_vla(cfg: Any) -> torch.nn.Module:
|
| 254 |
-
"""
|
| 255 |
-
Load and initialize the VLA model from checkpoint.
|
| 256 |
-
|
| 257 |
-
Args:
|
| 258 |
-
cfg: Configuration object
|
| 259 |
-
|
| 260 |
-
Returns:
|
| 261 |
-
torch.nn.Module: The initialized VLA model
|
| 262 |
-
"""
|
| 263 |
-
print("Instantiating pretrained VLA policy...")
|
| 264 |
-
|
| 265 |
-
# If loading a locally stored pretrained checkpoint, check whether config or model files
|
| 266 |
-
# need to be synced so that any changes the user makes to the VLA modeling code will
|
| 267 |
-
# actually go into effect
|
| 268 |
-
# If loading a pretrained checkpoint from Hugging Face Hub, we just assume that the policy
|
| 269 |
-
# will be used as is, with its original modeling logic
|
| 270 |
-
if not model_is_on_hf_hub(cfg.pretrained_checkpoint):
|
| 271 |
-
# Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
|
| 272 |
-
AutoConfig.register("openvla", OpenVLAConfig)
|
| 273 |
-
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
|
| 274 |
-
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
|
| 275 |
-
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
|
| 276 |
-
|
| 277 |
-
# Update config.json and sync model files
|
| 278 |
-
update_auto_map(cfg.pretrained_checkpoint)
|
| 279 |
-
check_model_logic_mismatch(cfg.pretrained_checkpoint)
|
| 280 |
-
|
| 281 |
-
# Load the model
|
| 282 |
-
vla = AutoModelForVision2Seq.from_pretrained(
|
| 283 |
-
cfg.pretrained_checkpoint,
|
| 284 |
-
# attn_implementation="flash_attention_2",
|
| 285 |
-
torch_dtype=torch.bfloat16,
|
| 286 |
-
load_in_8bit=cfg.load_in_8bit,
|
| 287 |
-
load_in_4bit=cfg.load_in_4bit,
|
| 288 |
-
low_cpu_mem_usage=True,
|
| 289 |
-
trust_remote_code=True,
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
# If using FiLM, wrap the vision backbone to allow for infusion of language inputs
|
| 293 |
-
if cfg.use_film:
|
| 294 |
-
vla = _apply_film_to_vla(vla, cfg)
|
| 295 |
-
|
| 296 |
-
# Set number of images in model input
|
| 297 |
-
vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
|
| 298 |
-
|
| 299 |
-
vla.eval()
|
| 300 |
-
|
| 301 |
-
# Move model to device if not using quantization
|
| 302 |
-
if not cfg.load_in_8bit and not cfg.load_in_4bit:
|
| 303 |
-
vla = vla.to(DEVICE)
|
| 304 |
-
|
| 305 |
-
# Load dataset stats for action normalization
|
| 306 |
-
_load_dataset_stats(vla, cfg.pretrained_checkpoint)
|
| 307 |
-
|
| 308 |
-
return vla
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
def _apply_film_to_vla(vla: torch.nn.Module, cfg: Any) -> torch.nn.Module:
|
| 312 |
-
"""
|
| 313 |
-
Apply FiLM (Feature-wise Linear Modulation) to the VLA vision backbone.
|
| 314 |
-
|
| 315 |
-
Args:
|
| 316 |
-
vla: The VLA model
|
| 317 |
-
cfg: Configuration object with model parameters
|
| 318 |
-
|
| 319 |
-
Returns:
|
| 320 |
-
torch.nn.Module: VLA model with FiLM applied
|
| 321 |
-
"""
|
| 322 |
-
from peft import LoraConfig, get_peft_model
|
| 323 |
-
|
| 324 |
-
# Apply LoRA configuration
|
| 325 |
-
lora_config = LoraConfig(
|
| 326 |
-
r=cfg.lora_rank,
|
| 327 |
-
lora_alpha=min(cfg.lora_rank, 16),
|
| 328 |
-
lora_dropout=0.0,
|
| 329 |
-
target_modules="all-linear",
|
| 330 |
-
init_lora_weights="gaussian",
|
| 331 |
-
)
|
| 332 |
-
vla = get_peft_model(vla, lora_config)
|
| 333 |
-
|
| 334 |
-
# Create and apply FiLMed vision backbone
|
| 335 |
-
new_vision_backbone = FiLMedPrismaticVisionBackbone(
|
| 336 |
-
vision_backbone=vla.vision_backbone, llm_dim=vla.llm_dim,
|
| 337 |
-
)
|
| 338 |
-
vla.model.vision_backbone = new_vision_backbone
|
| 339 |
-
|
| 340 |
-
# Load vision backbone checkpoint
|
| 341 |
-
checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "vision_backbone")
|
| 342 |
-
state_dict = torch.load(checkpoint_path, weights_only=True)
|
| 343 |
-
vla.model.vision_backbone.load_state_dict(state_dict)
|
| 344 |
-
|
| 345 |
-
# Use the model component instead of wrapper and convert to bfloat16
|
| 346 |
-
vla = vla.model
|
| 347 |
-
vla.vision_backbone = vla.vision_backbone.to(torch.bfloat16)
|
| 348 |
-
|
| 349 |
-
return vla
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
|
| 353 |
-
"""
|
| 354 |
-
Load dataset statistics used during training for action normalization.
|
| 355 |
-
|
| 356 |
-
Args:
|
| 357 |
-
vla: The VLA model
|
| 358 |
-
checkpoint_path: Path to the checkpoint directory
|
| 359 |
-
"""
|
| 360 |
-
if model_is_on_hf_hub(checkpoint_path):
|
| 361 |
-
# Download dataset stats directly from HF Hub
|
| 362 |
-
dataset_statistics_path = hf_hub_download(
|
| 363 |
-
repo_id=checkpoint_path,
|
| 364 |
-
filename="dataset_statistics.json",
|
| 365 |
-
)
|
| 366 |
-
else:
|
| 367 |
-
dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json")
|
| 368 |
-
if os.path.isfile(dataset_statistics_path):
|
| 369 |
-
with open(dataset_statistics_path, "r") as f:
|
| 370 |
-
norm_stats = json.load(f)
|
| 371 |
-
vla.norm_stats = norm_stats
|
| 372 |
-
else:
|
| 373 |
-
print(
|
| 374 |
-
"WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
|
| 375 |
-
"You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint."
|
| 376 |
-
"Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
def get_processor(cfg: Any) -> AutoProcessor:
|
| 381 |
-
"""
|
| 382 |
-
Get the VLA model's Hugging Face processor.
|
| 383 |
-
|
| 384 |
-
Args:
|
| 385 |
-
cfg: Configuration object with model parameters
|
| 386 |
-
|
| 387 |
-
Returns:
|
| 388 |
-
AutoProcessor: The model's processor
|
| 389 |
-
"""
|
| 390 |
-
return AutoProcessor.from_pretrained(cfg.pretrained_checkpoint, trust_remote_code=True)
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
def get_proprio_projector(cfg: Any, llm_dim: int, proprio_dim: int) -> ProprioProjector:
|
| 394 |
-
"""
|
| 395 |
-
Get proprioception projector for the VLA model.
|
| 396 |
-
|
| 397 |
-
Args:
|
| 398 |
-
cfg: Configuration object with model parameters
|
| 399 |
-
llm_dim: Dimension of the language model
|
| 400 |
-
proprio_dim: Dimension of proprioception data
|
| 401 |
-
|
| 402 |
-
Returns:
|
| 403 |
-
ProprioProjector: The initialized proprio projector
|
| 404 |
-
"""
|
| 405 |
-
# Initialize projector and move to device
|
| 406 |
-
proprio_projector = ProprioProjector(
|
| 407 |
-
llm_dim=llm_dim,
|
| 408 |
-
proprio_dim=proprio_dim,
|
| 409 |
-
).to(DEVICE)
|
| 410 |
-
proprio_projector = proprio_projector.to(torch.bfloat16).to(DEVICE)
|
| 411 |
-
proprio_projector.eval()
|
| 412 |
-
|
| 413 |
-
# Find and load checkpoint (may be on Hugging Face Hub or stored locally)
|
| 414 |
-
if model_is_on_hf_hub(cfg.pretrained_checkpoint):
|
| 415 |
-
model_path_to_proprio_projector_name = {
|
| 416 |
-
"moojink/openvla-7b-oft-finetuned-libero-spatial": "proprio_projector--150000_checkpoint.pt",
|
| 417 |
-
"moojink/openvla-7b-oft-finetuned-libero-object": "proprio_projector--150000_checkpoint.pt",
|
| 418 |
-
"moojink/openvla-7b-oft-finetuned-libero-goal": "proprio_projector--50000_checkpoint.pt",
|
| 419 |
-
"moojink/openvla-7b-oft-finetuned-libero-10": "proprio_projector--150000_checkpoint.pt",
|
| 420 |
-
"moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "proprio_projector--300000_checkpoint.pt",
|
| 421 |
-
}
|
| 422 |
-
if cfg.pretrained_checkpoint not in model_path_to_proprio_projector_name.keys():
|
| 423 |
-
raise ValueError("Unsupported HF Hub pretrained checkpoint found!")
|
| 424 |
-
# Download proprio projector directly from HF Hub
|
| 425 |
-
proprio_projector_path = hf_hub_download(
|
| 426 |
-
repo_id=cfg.pretrained_checkpoint, filename=model_path_to_proprio_projector_name[cfg.pretrained_checkpoint]
|
| 427 |
-
)
|
| 428 |
-
state_dict = load_component_state_dict(proprio_projector_path)
|
| 429 |
-
proprio_projector.load_state_dict(state_dict)
|
| 430 |
-
else:
|
| 431 |
-
checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "proprio_projector")
|
| 432 |
-
state_dict = load_component_state_dict(checkpoint_path)
|
| 433 |
-
proprio_projector.load_state_dict(state_dict)
|
| 434 |
-
|
| 435 |
-
return proprio_projector
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
def get_noisy_action_projector(cfg: Any, llm_dim: int) -> NoisyActionProjector:
|
| 439 |
-
"""
|
| 440 |
-
Get noisy action projector for diffusion-based action prediction.
|
| 441 |
-
|
| 442 |
-
Args:
|
| 443 |
-
cfg: Configuration object with model parameters
|
| 444 |
-
llm_dim: Dimension of the language model
|
| 445 |
-
|
| 446 |
-
Returns:
|
| 447 |
-
NoisyActionProjector: The initialized noisy action projector
|
| 448 |
-
"""
|
| 449 |
-
# Initialize projector and move to device
|
| 450 |
-
noisy_action_projector = NoisyActionProjector(
|
| 451 |
-
llm_dim=llm_dim,
|
| 452 |
-
).to(DEVICE)
|
| 453 |
-
noisy_action_projector = noisy_action_projector.to(torch.bfloat16).to(DEVICE)
|
| 454 |
-
noisy_action_projector.eval()
|
| 455 |
-
|
| 456 |
-
# Find and load checkpoint
|
| 457 |
-
checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "noisy_action_projector")
|
| 458 |
-
state_dict = load_component_state_dict(checkpoint_path)
|
| 459 |
-
noisy_action_projector.load_state_dict(state_dict)
|
| 460 |
-
|
| 461 |
-
return noisy_action_projector
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
def get_action_head(cfg: Any, llm_dim: int) -> Union[L1RegressionActionHead, DiffusionActionHead]:
|
| 465 |
-
"""
|
| 466 |
-
Get action head for continuous value prediction.
|
| 467 |
-
|
| 468 |
-
Args:
|
| 469 |
-
cfg: Configuration object with model parameters
|
| 470 |
-
llm_dim: Dimension of the language model
|
| 471 |
-
|
| 472 |
-
Returns:
|
| 473 |
-
Union[L1RegressionActionHead, DiffusionActionHead]: The initialized action head
|
| 474 |
-
|
| 475 |
-
Raises:
|
| 476 |
-
AssertionError: If both L1 regression and diffusion are specified
|
| 477 |
-
"""
|
| 478 |
-
assert not (cfg.use_l1_regression and cfg.use_diffusion), "Cannot use both L1 regression and diffusion action head!"
|
| 479 |
-
|
| 480 |
-
# Initialize appropriate action head based on configuration
|
| 481 |
-
if cfg.use_l1_regression:
|
| 482 |
-
action_head = L1RegressionActionHead(input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM)
|
| 483 |
-
elif cfg.use_diffusion:
|
| 484 |
-
action_head = DiffusionActionHead(
|
| 485 |
-
input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM, num_diffusion_steps_train=cfg.num_diffusion_steps_train
|
| 486 |
-
)
|
| 487 |
-
# Set number of diffusion steps for inference
|
| 488 |
-
action_head.noise_scheduler.set_timesteps(cfg.num_diffusion_steps_inference)
|
| 489 |
-
else:
|
| 490 |
-
raise ValueError("Either use_l1_regression or use_diffusion must be True")
|
| 491 |
-
|
| 492 |
-
action_head = action_head.to(torch.bfloat16).to(DEVICE)
|
| 493 |
-
action_head.eval()
|
| 494 |
-
|
| 495 |
-
# Find and load checkpoint (may be on Hugging Face Hub or stored locally)
|
| 496 |
-
if model_is_on_hf_hub(cfg.pretrained_checkpoint):
|
| 497 |
-
model_path_to_action_head_name = {
|
| 498 |
-
"moojink/openvla-7b-oft-finetuned-libero-spatial": "action_head--150000_checkpoint.pt",
|
| 499 |
-
"moojink/openvla-7b-oft-finetuned-libero-object": "action_head--150000_checkpoint.pt",
|
| 500 |
-
"moojink/openvla-7b-oft-finetuned-libero-goal": "action_head--50000_checkpoint.pt",
|
| 501 |
-
"moojink/openvla-7b-oft-finetuned-libero-10": "action_head--150000_checkpoint.pt",
|
| 502 |
-
"moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "action_head--300000_checkpoint.pt",
|
| 503 |
-
}
|
| 504 |
-
if cfg.pretrained_checkpoint not in model_path_to_action_head_name.keys():
|
| 505 |
-
raise ValueError("Unsupported HF Hub pretrained checkpoint found!")
|
| 506 |
-
# Download proprio projector directly from HF Hub
|
| 507 |
-
action_head_path = hf_hub_download(
|
| 508 |
-
repo_id=cfg.pretrained_checkpoint, filename=model_path_to_action_head_name[cfg.pretrained_checkpoint]
|
| 509 |
-
)
|
| 510 |
-
state_dict = load_component_state_dict(action_head_path)
|
| 511 |
-
action_head.load_state_dict(state_dict)
|
| 512 |
-
else:
|
| 513 |
-
checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "action_head")
|
| 514 |
-
state_dict = load_component_state_dict(checkpoint_path)
|
| 515 |
-
action_head.load_state_dict(state_dict)
|
| 516 |
-
|
| 517 |
-
return action_head
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
def resize_image_for_policy(img: np.ndarray, resize_size: Union[int, Tuple[int, int]]) -> np.ndarray:
|
| 521 |
-
"""
|
| 522 |
-
Resize an image to match the policy's expected input size.
|
| 523 |
-
|
| 524 |
-
Uses the same resizing scheme as in the training data pipeline for distribution matching.
|
| 525 |
-
|
| 526 |
-
Args:
|
| 527 |
-
img: Numpy array containing the image
|
| 528 |
-
resize_size: Target size as int (square) or (height, width) tuple
|
| 529 |
-
|
| 530 |
-
Returns:
|
| 531 |
-
np.ndarray: The resized image
|
| 532 |
-
"""
|
| 533 |
-
assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
|
| 534 |
-
if isinstance(resize_size, int):
|
| 535 |
-
resize_size = (resize_size, resize_size)
|
| 536 |
-
|
| 537 |
-
# Resize using the same pipeline as in RLDS dataset builder
|
| 538 |
-
img = tf.image.encode_jpeg(img) # Encode as JPEG
|
| 539 |
-
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Decode back
|
| 540 |
-
img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
|
| 541 |
-
img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
|
| 542 |
-
|
| 543 |
-
return img.numpy()
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
def crop_and_resize(image: tf.Tensor, crop_scale: float, batch_size: int) -> tf.Tensor:
|
| 547 |
-
"""
|
| 548 |
-
Center-crop an image and resize it back to original dimensions.
|
| 549 |
-
|
| 550 |
-
Uses the same logic as in the training data pipeline for distribution matching.
|
| 551 |
-
|
| 552 |
-
Args:
|
| 553 |
-
image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) with values in [0,1]
|
| 554 |
-
crop_scale: Area of center crop relative to original image
|
| 555 |
-
batch_size: Batch size
|
| 556 |
-
|
| 557 |
-
Returns:
|
| 558 |
-
tf.Tensor: The cropped and resized image
|
| 559 |
-
"""
|
| 560 |
-
# Handle 3D inputs by adding batch dimension if needed
|
| 561 |
-
assert image.shape.ndims in (3, 4), "Image must be 3D or 4D tensor"
|
| 562 |
-
expanded_dims = False
|
| 563 |
-
if image.shape.ndims == 3:
|
| 564 |
-
image = tf.expand_dims(image, axis=0)
|
| 565 |
-
expanded_dims = True
|
| 566 |
-
|
| 567 |
-
# Calculate crop dimensions (note: we use sqrt(crop_scale) for h/w)
|
| 568 |
-
new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
|
| 569 |
-
new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))
|
| 570 |
-
|
| 571 |
-
# Create bounding box for the crop
|
| 572 |
-
height_offsets = (1 - new_heights) / 2
|
| 573 |
-
width_offsets = (1 - new_widths) / 2
|
| 574 |
-
bounding_boxes = tf.stack(
|
| 575 |
-
[
|
| 576 |
-
height_offsets,
|
| 577 |
-
width_offsets,
|
| 578 |
-
height_offsets + new_heights,
|
| 579 |
-
width_offsets + new_widths,
|
| 580 |
-
],
|
| 581 |
-
axis=1,
|
| 582 |
-
)
|
| 583 |
-
|
| 584 |
-
# Apply crop and resize
|
| 585 |
-
image = tf.image.crop_and_resize(
|
| 586 |
-
image, bounding_boxes, tf.range(batch_size), (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE)
|
| 587 |
-
)
|
| 588 |
-
|
| 589 |
-
# Remove batch dimension if it was added
|
| 590 |
-
if expanded_dims:
|
| 591 |
-
image = image[0]
|
| 592 |
-
|
| 593 |
-
return image
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
|
| 597 |
-
"""
|
| 598 |
-
Center crop an image to match training data distribution.
|
| 599 |
-
|
| 600 |
-
Args:
|
| 601 |
-
image: Input image (PIL or numpy array)
|
| 602 |
-
|
| 603 |
-
Returns:
|
| 604 |
-
Image.Image: Cropped PIL Image
|
| 605 |
-
"""
|
| 606 |
-
batch_size = 1
|
| 607 |
-
crop_scale = 0.9
|
| 608 |
-
|
| 609 |
-
# Convert to TF Tensor if needed
|
| 610 |
-
if not isinstance(image, tf.Tensor):
|
| 611 |
-
image = tf.convert_to_tensor(np.array(image))
|
| 612 |
-
|
| 613 |
-
orig_dtype = image.dtype
|
| 614 |
-
|
| 615 |
-
# Convert to float32 in range [0,1]
|
| 616 |
-
image = tf.image.convert_image_dtype(image, tf.float32)
|
| 617 |
-
|
| 618 |
-
# Apply center crop and resize
|
| 619 |
-
image = crop_and_resize(image, crop_scale, batch_size)
|
| 620 |
-
|
| 621 |
-
# Convert back to original data type
|
| 622 |
-
image = tf.clip_by_value(image, 0, 1)
|
| 623 |
-
image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True)
|
| 624 |
-
|
| 625 |
-
# Convert to PIL Image
|
| 626 |
-
return Image.fromarray(image.numpy()).convert("RGB")
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
def check_image_format(image: Any) -> None:
|
| 630 |
-
"""
|
| 631 |
-
Validate input image format.
|
| 632 |
-
|
| 633 |
-
Args:
|
| 634 |
-
image: Image to check
|
| 635 |
-
|
| 636 |
-
Raises:
|
| 637 |
-
AssertionError: If image format is invalid
|
| 638 |
-
"""
|
| 639 |
-
is_numpy_array = isinstance(image, np.ndarray)
|
| 640 |
-
has_correct_shape = len(image.shape) == 3 and image.shape[-1] == 3
|
| 641 |
-
has_correct_dtype = image.dtype == np.uint8
|
| 642 |
-
|
| 643 |
-
assert is_numpy_array and has_correct_shape and has_correct_dtype, (
|
| 644 |
-
"Incorrect image format detected! Make sure that the input image is a "
|
| 645 |
-
"numpy array with shape (H, W, 3) and dtype np.uint8!"
|
| 646 |
-
)
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray:
|
| 650 |
-
"""
|
| 651 |
-
Normalize proprioception data to match training distribution.
|
| 652 |
-
|
| 653 |
-
Args:
|
| 654 |
-
proprio: Raw proprioception data
|
| 655 |
-
norm_stats: Normalization statistics
|
| 656 |
-
|
| 657 |
-
Returns:
|
| 658 |
-
np.ndarray: Normalized proprioception data
|
| 659 |
-
"""
|
| 660 |
-
if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
|
| 661 |
-
mask = norm_stats.get("mask", np.ones_like(norm_stats["min"], dtype=bool))
|
| 662 |
-
proprio_high, proprio_low = np.array(norm_stats["max"]), np.array(norm_stats["min"])
|
| 663 |
-
elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
|
| 664 |
-
mask = norm_stats.get("mask", np.ones_like(norm_stats["q01"], dtype=bool))
|
| 665 |
-
proprio_high, proprio_low = np.array(norm_stats["q99"]), np.array(norm_stats["q01"])
|
| 666 |
-
else:
|
| 667 |
-
raise ValueError("Unsupported action/proprio normalization type detected!")
|
| 668 |
-
|
| 669 |
-
normalized_proprio = np.clip(
|
| 670 |
-
np.where(
|
| 671 |
-
mask,
|
| 672 |
-
2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1,
|
| 673 |
-
proprio,
|
| 674 |
-
),
|
| 675 |
-
a_min=-1.0,
|
| 676 |
-
a_max=1.0,
|
| 677 |
-
)
|
| 678 |
-
|
| 679 |
-
return normalized_proprio
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
def prepare_images_for_vla(images: List[np.ndarray], cfg: Any) -> List[Image.Image]:
|
| 683 |
-
"""
|
| 684 |
-
Prepare images for VLA input by resizing and cropping as needed.
|
| 685 |
-
|
| 686 |
-
Args:
|
| 687 |
-
images: List of input images as numpy arrays
|
| 688 |
-
cfg: Configuration object with parameters
|
| 689 |
-
|
| 690 |
-
Returns:
|
| 691 |
-
List[Image.Image]: Processed images ready for the model
|
| 692 |
-
"""
|
| 693 |
-
processed_images = []
|
| 694 |
-
|
| 695 |
-
for image in images:
|
| 696 |
-
# Validate format
|
| 697 |
-
check_image_format(image)
|
| 698 |
-
|
| 699 |
-
# Resize if needed
|
| 700 |
-
if image.shape != (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE, 3):
|
| 701 |
-
image = resize_image_for_policy(image, OPENVLA_IMAGE_SIZE)
|
| 702 |
-
|
| 703 |
-
# Convert to PIL image
|
| 704 |
-
pil_image = Image.fromarray(image).convert("RGB")
|
| 705 |
-
|
| 706 |
-
# Apply center crop if configured
|
| 707 |
-
if cfg.center_crop:
|
| 708 |
-
pil_image = center_crop_image(pil_image)
|
| 709 |
-
|
| 710 |
-
processed_images.append(pil_image)
|
| 711 |
-
|
| 712 |
-
return processed_images
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
def get_vla_action(
|
| 716 |
-
cfg: Any,
|
| 717 |
-
vla: torch.nn.Module,
|
| 718 |
-
processor: Any,
|
| 719 |
-
obs: Dict[str, Any],
|
| 720 |
-
task_label: str,
|
| 721 |
-
action_head: Optional[torch.nn.Module] = None,
|
| 722 |
-
proprio_projector: Optional[torch.nn.Module] = None,
|
| 723 |
-
noisy_action_projector: Optional[torch.nn.Module] = None,
|
| 724 |
-
use_film: bool = False,
|
| 725 |
-
) -> List[np.ndarray]:
|
| 726 |
-
"""
|
| 727 |
-
Generate action predictions with the VLA policy.
|
| 728 |
-
|
| 729 |
-
Args:
|
| 730 |
-
cfg: Configuration object with parameters
|
| 731 |
-
vla: The VLA model
|
| 732 |
-
processor: Model processor for inputs
|
| 733 |
-
obs: Observation dictionary
|
| 734 |
-
task_label: Text description of the task
|
| 735 |
-
action_head: Optional action head for continuous actions
|
| 736 |
-
proprio_projector: Optional proprioception projector
|
| 737 |
-
noisy_action_projector: Optional noisy action projector for diffusion
|
| 738 |
-
use_film: Whether to use FiLM
|
| 739 |
-
|
| 740 |
-
Returns:
|
| 741 |
-
List[np.ndarray]: Predicted actions
|
| 742 |
-
"""
|
| 743 |
-
with torch.inference_mode():
|
| 744 |
-
|
| 745 |
-
# Collect all input images
|
| 746 |
-
all_images = [obs["full_image"]]
|
| 747 |
-
if cfg.num_images_in_input > 1:
|
| 748 |
-
all_images.extend([obs[k] for k in obs.keys() if "wrist" in k])
|
| 749 |
-
|
| 750 |
-
# Process images
|
| 751 |
-
all_images = prepare_images_for_vla(all_images, cfg)
|
| 752 |
-
|
| 753 |
-
# Extract primary image and additional images
|
| 754 |
-
primary_image = all_images.pop(0)
|
| 755 |
-
|
| 756 |
-
# Build VLA prompt
|
| 757 |
-
prompt = f"In: What action should the robot take to {task_label.lower()}?\nOut:"
|
| 758 |
-
|
| 759 |
-
# Process primary image
|
| 760 |
-
inputs = processor(prompt, primary_image).to(DEVICE, dtype=torch.bfloat16)
|
| 761 |
-
|
| 762 |
-
# Process additional wrist images if any
|
| 763 |
-
if all_images:
|
| 764 |
-
all_wrist_inputs = [
|
| 765 |
-
processor(prompt, image_wrist).to(DEVICE, dtype=torch.bfloat16) for image_wrist in all_images
|
| 766 |
-
]
|
| 767 |
-
# Concatenate all images
|
| 768 |
-
primary_pixel_values = inputs["pixel_values"]
|
| 769 |
-
all_wrist_pixel_values = [wrist_inputs["pixel_values"] for wrist_inputs in all_wrist_inputs]
|
| 770 |
-
inputs["pixel_values"] = torch.cat([primary_pixel_values] + all_wrist_pixel_values, dim=1)
|
| 771 |
-
|
| 772 |
-
# Process proprioception data if used
|
| 773 |
-
proprio = None
|
| 774 |
-
if cfg.use_proprio:
|
| 775 |
-
proprio = obs["state"]
|
| 776 |
-
proprio_norm_stats = vla.norm_stats[cfg.unnorm_key]["proprio"]
|
| 777 |
-
obs["state"] = normalize_proprio(proprio, proprio_norm_stats)
|
| 778 |
-
proprio = obs["state"]
|
| 779 |
-
|
| 780 |
-
# Generate action
|
| 781 |
-
if action_head is None:
|
| 782 |
-
# Standard VLA output (single-image inputs, discrete actions)
|
| 783 |
-
action, _ = vla.predict_action(**inputs, unnorm_key=cfg.unnorm_key, do_sample=False)
|
| 784 |
-
else:
|
| 785 |
-
# Custom action head for continuous actions
|
| 786 |
-
action, _ = vla.predict_action(
|
| 787 |
-
**inputs,
|
| 788 |
-
unnorm_key=cfg.unnorm_key,
|
| 789 |
-
do_sample=False,
|
| 790 |
-
proprio=proprio,
|
| 791 |
-
proprio_projector=proprio_projector,
|
| 792 |
-
noisy_action_projector=noisy_action_projector,
|
| 793 |
-
action_head=action_head,
|
| 794 |
-
use_film=use_film,
|
| 795 |
-
)
|
| 796 |
-
|
| 797 |
-
# Return action chunk as list of actions
|
| 798 |
-
return [action[i] for i in range(len(action))]
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
def get_action_from_server(
|
| 802 |
-
observation: Dict[str, Any], server_endpoint: str = "http://0.0.0.0:8777/act"
|
| 803 |
-
) -> Dict[str, Any]:
|
| 804 |
-
"""
|
| 805 |
-
Get VLA action from remote inference server.
|
| 806 |
-
|
| 807 |
-
Args:
|
| 808 |
-
observation: Observation data to send to server
|
| 809 |
-
server_endpoint: URL of the inference server
|
| 810 |
-
|
| 811 |
-
Returns:
|
| 812 |
-
Dict[str, Any]: Action response from server
|
| 813 |
-
"""
|
| 814 |
-
response = requests.post(
|
| 815 |
-
server_endpoint,
|
| 816 |
-
json=observation,
|
| 817 |
-
)
|
| 818 |
-
return response.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/experiments/robot/robot_utils.py
DELETED
|
@@ -1,199 +0,0 @@
|
|
| 1 |
-
"""Utils for evaluating robot policies in various environments."""
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
-
import time
|
| 6 |
-
from typing import Any, Dict, List, Optional, Union
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from experiments.robot.openvla_utils import (
|
| 12 |
-
get_vla,
|
| 13 |
-
get_vla_action,
|
| 14 |
-
)
|
| 15 |
-
|
| 16 |
-
# Initialize important constants
|
| 17 |
-
ACTION_DIM = 7
|
| 18 |
-
DATE = time.strftime("%Y_%m_%d")
|
| 19 |
-
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
|
| 20 |
-
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
| 21 |
-
|
| 22 |
-
# Configure NumPy print settings
|
| 23 |
-
np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
|
| 24 |
-
|
| 25 |
-
# Initialize system prompt for OpenVLA v0.1
|
| 26 |
-
OPENVLA_V01_SYSTEM_PROMPT = (
|
| 27 |
-
"A chat between a curious user and an artificial intelligence assistant. "
|
| 28 |
-
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
# Model image size configuration
|
| 32 |
-
MODEL_IMAGE_SIZES = {
|
| 33 |
-
"openvla": 224,
|
| 34 |
-
# Add other models as needed
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def set_seed_everywhere(seed: int) -> None:
|
| 39 |
-
"""
|
| 40 |
-
Set random seed for all random number generators for reproducibility.
|
| 41 |
-
|
| 42 |
-
Args:
|
| 43 |
-
seed: The random seed to use
|
| 44 |
-
"""
|
| 45 |
-
torch.manual_seed(seed)
|
| 46 |
-
torch.cuda.manual_seed_all(seed)
|
| 47 |
-
np.random.seed(seed)
|
| 48 |
-
random.seed(seed)
|
| 49 |
-
torch.backends.cudnn.deterministic = True
|
| 50 |
-
torch.backends.cudnn.benchmark = False
|
| 51 |
-
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def get_model(cfg: Any, wrap_diffusion_policy_for_droid: bool = False) -> torch.nn.Module:
|
| 55 |
-
"""
|
| 56 |
-
Load and initialize model for evaluation based on configuration.
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
cfg: Configuration object with model parameters
|
| 60 |
-
wrap_diffusion_policy_for_droid: Whether to wrap diffusion policy for DROID
|
| 61 |
-
|
| 62 |
-
Returns:
|
| 63 |
-
torch.nn.Module: The loaded model
|
| 64 |
-
|
| 65 |
-
Raises:
|
| 66 |
-
ValueError: If model family is not supported
|
| 67 |
-
"""
|
| 68 |
-
if cfg.model_family == "openvla":
|
| 69 |
-
model = get_vla(cfg)
|
| 70 |
-
else:
|
| 71 |
-
raise ValueError(f"Unsupported model family: {cfg.model_family}")
|
| 72 |
-
|
| 73 |
-
print(f"Loaded model: {type(model)}")
|
| 74 |
-
return model
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def get_image_resize_size(cfg: Any) -> Union[int, tuple]:
|
| 78 |
-
"""
|
| 79 |
-
Get image resize dimensions for a specific model.
|
| 80 |
-
|
| 81 |
-
If returned value is an int, the resized image will be a square.
|
| 82 |
-
If returned value is a tuple, the resized image will be a rectangle.
|
| 83 |
-
|
| 84 |
-
Args:
|
| 85 |
-
cfg: Configuration object with model parameters
|
| 86 |
-
|
| 87 |
-
Returns:
|
| 88 |
-
Union[int, tuple]: Image resize dimensions
|
| 89 |
-
|
| 90 |
-
Raises:
|
| 91 |
-
ValueError: If model family is not supported
|
| 92 |
-
"""
|
| 93 |
-
if cfg.model_family not in MODEL_IMAGE_SIZES:
|
| 94 |
-
raise ValueError(f"Unsupported model family: {cfg.model_family}")
|
| 95 |
-
|
| 96 |
-
return MODEL_IMAGE_SIZES[cfg.model_family]
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def get_action(
|
| 100 |
-
cfg: Any,
|
| 101 |
-
model: torch.nn.Module,
|
| 102 |
-
obs: Dict[str, Any],
|
| 103 |
-
task_label: str,
|
| 104 |
-
processor: Optional[Any] = None,
|
| 105 |
-
action_head: Optional[torch.nn.Module] = None,
|
| 106 |
-
proprio_projector: Optional[torch.nn.Module] = None,
|
| 107 |
-
noisy_action_projector: Optional[torch.nn.Module] = None,
|
| 108 |
-
use_film: bool = False,
|
| 109 |
-
) -> Union[List[np.ndarray], np.ndarray]:
|
| 110 |
-
"""
|
| 111 |
-
Query the model to get action predictions.
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
cfg: Configuration object with model parameters
|
| 115 |
-
model: The loaded model
|
| 116 |
-
obs: Observation dictionary
|
| 117 |
-
task_label: Text description of the task
|
| 118 |
-
processor: Model processor for inputs
|
| 119 |
-
action_head: Optional action head for continuous actions
|
| 120 |
-
proprio_projector: Optional proprioception projector
|
| 121 |
-
noisy_action_projector: Optional noisy action projector for diffusion
|
| 122 |
-
use_film: Whether to use FiLM
|
| 123 |
-
|
| 124 |
-
Returns:
|
| 125 |
-
Union[List[np.ndarray], np.ndarray]: Predicted actions
|
| 126 |
-
|
| 127 |
-
Raises:
|
| 128 |
-
ValueError: If model family is not supported
|
| 129 |
-
"""
|
| 130 |
-
with torch.no_grad():
|
| 131 |
-
if cfg.model_family == "openvla":
|
| 132 |
-
action = get_vla_action(
|
| 133 |
-
cfg=cfg,
|
| 134 |
-
vla=model,
|
| 135 |
-
processor=processor,
|
| 136 |
-
obs=obs,
|
| 137 |
-
task_label=task_label,
|
| 138 |
-
action_head=action_head,
|
| 139 |
-
proprio_projector=proprio_projector,
|
| 140 |
-
noisy_action_projector=noisy_action_projector,
|
| 141 |
-
use_film=use_film,
|
| 142 |
-
)
|
| 143 |
-
else:
|
| 144 |
-
raise ValueError(f"Unsupported model family: {cfg.model_family}")
|
| 145 |
-
|
| 146 |
-
return action
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def normalize_gripper_action(action: np.ndarray, binarize: bool = True) -> np.ndarray:
|
| 150 |
-
"""
|
| 151 |
-
Normalize gripper action from [0,1] to [-1,+1] range.
|
| 152 |
-
|
| 153 |
-
This is necessary for some environments because the dataset wrapper
|
| 154 |
-
standardizes gripper actions to [0,1]. Note that unlike the other action
|
| 155 |
-
dimensions, the gripper action is not normalized to [-1,+1] by default.
|
| 156 |
-
|
| 157 |
-
Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
action: Action array with gripper action in the last dimension
|
| 161 |
-
binarize: Whether to binarize gripper action to -1 or +1
|
| 162 |
-
|
| 163 |
-
Returns:
|
| 164 |
-
np.ndarray: Action array with normalized gripper action
|
| 165 |
-
"""
|
| 166 |
-
# Create a copy to avoid modifying the original
|
| 167 |
-
normalized_action = action.copy()
|
| 168 |
-
|
| 169 |
-
# Normalize the last action dimension to [-1,+1]
|
| 170 |
-
orig_low, orig_high = 0.0, 1.0
|
| 171 |
-
normalized_action[..., -1] = 2 * (normalized_action[..., -1] - orig_low) / (orig_high - orig_low) - 1
|
| 172 |
-
|
| 173 |
-
if binarize:
|
| 174 |
-
# Binarize to -1 or +1
|
| 175 |
-
normalized_action[..., -1] = np.sign(normalized_action[..., -1])
|
| 176 |
-
|
| 177 |
-
return normalized_action
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
def invert_gripper_action(action: np.ndarray) -> np.ndarray:
|
| 181 |
-
"""
|
| 182 |
-
Flip the sign of the gripper action (last dimension of action vector).
|
| 183 |
-
|
| 184 |
-
This is necessary for environments where -1 = open, +1 = close, since
|
| 185 |
-
the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.
|
| 186 |
-
|
| 187 |
-
Args:
|
| 188 |
-
action: Action array with gripper action in the last dimension
|
| 189 |
-
|
| 190 |
-
Returns:
|
| 191 |
-
np.ndarray: Action array with inverted gripper action
|
| 192 |
-
"""
|
| 193 |
-
# Create a copy to avoid modifying the original
|
| 194 |
-
inverted_action = action.copy()
|
| 195 |
-
|
| 196 |
-
# Invert the gripper action
|
| 197 |
-
inverted_action[..., -1] *= -1.0
|
| 198 |
-
|
| 199 |
-
return inverted_action
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
from .models import available_model_names, available_models, get_model_description, load
|
|
|
|
|
|
capvector-oft/prismatic/conf/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from .datasets import DatasetConfig, DatasetRegistry
|
| 2 |
-
from .models import ModelConfig, ModelRegistry
|
| 3 |
-
from .vla import VLAConfig, VLARegistry
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/conf/datasets.py
DELETED
|
@@ -1,133 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
datasets.py
|
| 3 |
-
|
| 4 |
-
Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant
|
| 5 |
-
and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes:
|
| 6 |
-
- Dataset Variant (Identifier) --> e.g., "llava-v15"
|
| 7 |
-
- Align Stage Dataset Components (annotations, images)
|
| 8 |
-
- Finetune Stage Dataset Components (annotations, images)
|
| 9 |
-
- Dataset Root Directory (Path)
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from enum import Enum, unique
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Tuple
|
| 16 |
-
|
| 17 |
-
from draccus import ChoiceRegistry
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
@dataclass
|
| 21 |
-
class DatasetConfig(ChoiceRegistry):
|
| 22 |
-
# fmt: off
|
| 23 |
-
dataset_id: str # Unique ID that fully specifies a dataset variant
|
| 24 |
-
|
| 25 |
-
# Dataset Components for each Stage in < align | finetune >
|
| 26 |
-
align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage
|
| 27 |
-
finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage
|
| 28 |
-
|
| 29 |
-
dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root
|
| 30 |
-
# fmt: on
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
# [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
|
| 34 |
-
@dataclass
|
| 35 |
-
class LLaVa_V15_Config(DatasetConfig):
|
| 36 |
-
dataset_id: str = "llava-v15"
|
| 37 |
-
|
| 38 |
-
align_stage_components: Tuple[Path, Path] = (
|
| 39 |
-
Path("download/llava-laion-cc-sbu-558k/chat.json"),
|
| 40 |
-
Path("download/llava-laion-cc-sbu-558k/"),
|
| 41 |
-
)
|
| 42 |
-
finetune_stage_components: Tuple[Path, Path] = (
|
| 43 |
-
Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"),
|
| 44 |
-
Path("download/llava-v1.5-instruct/"),
|
| 45 |
-
)
|
| 46 |
-
dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training)
|
| 50 |
-
@dataclass
|
| 51 |
-
class LLaVa_Multimodal_Only_Config(DatasetConfig):
|
| 52 |
-
dataset_id: str = "llava-multimodal"
|
| 53 |
-
|
| 54 |
-
align_stage_components: Tuple[Path, Path] = (
|
| 55 |
-
Path("download/llava-laion-cc-sbu-558k/chat.json"),
|
| 56 |
-
Path("download/llava-laion-cc-sbu-558k/"),
|
| 57 |
-
)
|
| 58 |
-
finetune_stage_components: Tuple[Path, Path] = (
|
| 59 |
-
Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
|
| 60 |
-
Path("download/llava-v1.5-instruct/"),
|
| 61 |
-
)
|
| 62 |
-
dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# LLaVa-v15 + LVIS-Instruct-4V
|
| 66 |
-
@dataclass
|
| 67 |
-
class LLaVa_LVIS4V_Config(DatasetConfig):
|
| 68 |
-
dataset_id: str = "llava-lvis4v"
|
| 69 |
-
|
| 70 |
-
align_stage_components: Tuple[Path, Path] = (
|
| 71 |
-
Path("download/llava-laion-cc-sbu-558k/chat.json"),
|
| 72 |
-
Path("download/llava-laion-cc-sbu-558k/"),
|
| 73 |
-
)
|
| 74 |
-
finetune_stage_components: Tuple[Path, Path] = (
|
| 75 |
-
Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
|
| 76 |
-
Path("download/llava-v1.5-instruct/"),
|
| 77 |
-
)
|
| 78 |
-
dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
# LLaVa-v15 + LRV-Instruct
|
| 82 |
-
@dataclass
|
| 83 |
-
class LLaVa_LRV_Config(DatasetConfig):
|
| 84 |
-
dataset_id: str = "llava-lrv"
|
| 85 |
-
|
| 86 |
-
align_stage_components: Tuple[Path, Path] = (
|
| 87 |
-
Path("download/llava-laion-cc-sbu-558k/chat.json"),
|
| 88 |
-
Path("download/llava-laion-cc-sbu-558k/"),
|
| 89 |
-
)
|
| 90 |
-
finetune_stage_components: Tuple[Path, Path] = (
|
| 91 |
-
Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
|
| 92 |
-
Path("download/llava-v1.5-instruct/"),
|
| 93 |
-
)
|
| 94 |
-
dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
|
| 98 |
-
@dataclass
|
| 99 |
-
class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
|
| 100 |
-
dataset_id: str = "llava-lvis4v-lrv"
|
| 101 |
-
|
| 102 |
-
align_stage_components: Tuple[Path, Path] = (
|
| 103 |
-
Path("download/llava-laion-cc-sbu-558k/chat.json"),
|
| 104 |
-
Path("download/llava-laion-cc-sbu-558k/"),
|
| 105 |
-
)
|
| 106 |
-
finetune_stage_components: Tuple[Path, Path] = (
|
| 107 |
-
Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
|
| 108 |
-
Path("download/llava-v1.5-instruct/"),
|
| 109 |
-
)
|
| 110 |
-
dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
# === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! ===
|
| 114 |
-
@unique
|
| 115 |
-
class DatasetRegistry(Enum):
|
| 116 |
-
# === LLaVa v1.5 ===
|
| 117 |
-
LLAVA_V15 = LLaVa_V15_Config
|
| 118 |
-
|
| 119 |
-
LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config
|
| 120 |
-
|
| 121 |
-
LLAVA_LVIS4V = LLaVa_LVIS4V_Config
|
| 122 |
-
LLAVA_LRV = LLaVa_LRV_Config
|
| 123 |
-
|
| 124 |
-
LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config
|
| 125 |
-
|
| 126 |
-
@property
|
| 127 |
-
def dataset_id(self) -> str:
|
| 128 |
-
return self.value.dataset_id
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
# Register Datasets in Choice Registry
|
| 132 |
-
for dataset_variant in DatasetRegistry:
|
| 133 |
-
DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/conf/models.py
DELETED
|
@@ -1,584 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
models.py
|
| 3 |
-
|
| 4 |
-
Draccus Dataclass Definition for a ModelConfig object, with various registered subclasses for each model family and
|
| 5 |
-
variant thereof. A given model variant configures the following attributes:
|
| 6 |
-
- Pretrained Visual Representation (e.g., OpenAI CLIP ViT-L/14) + Pretrained LLM Backbone (e.g., LLaMa-2 7B)
|
| 7 |
-
- VLM Configuration + Parameters (e.g., MLP Projector, Image Preprocessing, etc.)
|
| 8 |
-
- [Optional] Stage 1 (`align`) Optimization Hyperparameters
|
| 9 |
-
- Stage 2 (`finetune`) Optimization Hyperparameters
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from enum import Enum, unique
|
| 14 |
-
from typing import Optional
|
| 15 |
-
|
| 16 |
-
from draccus import ChoiceRegistry
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
@dataclass
|
| 20 |
-
class ModelConfig(ChoiceRegistry):
|
| 21 |
-
# fmt: off
|
| 22 |
-
model_id: str # Unique Model ID that fully specifies a given variant
|
| 23 |
-
arch_specifier: str # Architecture specifier string (e.g., "gelu-mlp")
|
| 24 |
-
|
| 25 |
-
# Pretrained Backbones
|
| 26 |
-
vision_backbone_id: str # Pretrained Visual Featurizer (from TIMM) to load
|
| 27 |
-
llm_backbone_id: str # Pretrained LLM (from HF Transformers) to load
|
| 28 |
-
|
| 29 |
-
# Backbone Parameters
|
| 30 |
-
image_resize_strategy: str # Resizing strategy in < crop | letterbox | corner-pad >
|
| 31 |
-
llm_max_length: int # Maximum context length for LLM (can be < than max!)
|
| 32 |
-
|
| 33 |
-
# === Multi-Stage Optimization Hyperparameters ===
|
| 34 |
-
# By default, we assume an AdamW optimizer with FSDP (Gradient Sharding or Full Sharding depending on stage)
|
| 35 |
-
|
| 36 |
-
# Align Stage Optimization Parameters
|
| 37 |
-
align_epochs: int # Epochs to Run (in case `max_steps` is not specified)
|
| 38 |
-
align_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs)
|
| 39 |
-
align_global_batch_size: int # Global Batch Size (divided across processes)
|
| 40 |
-
align_per_device_batch_size: int # Per-Device Batch Size (per-process)
|
| 41 |
-
# => # of accumulation steps is auto-computed
|
| 42 |
-
|
| 43 |
-
align_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay)
|
| 44 |
-
align_weight_decay: float # Weight Decay for AdamW Optimizer
|
| 45 |
-
align_max_grad_norm: float # Max Grad Norm (for global gradient clipping)
|
| 46 |
-
align_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay")
|
| 47 |
-
align_warmup_ratio: float # Fraction of total steps to warmup
|
| 48 |
-
|
| 49 |
-
align_train_strategy: str # Align Train Strategy (default: "fsdp-shard-grad-op")
|
| 50 |
-
|
| 51 |
-
# Finetune Stage Optimization Parameters
|
| 52 |
-
finetune_epochs: int # Epochs to Run (in case `max_steps` is not specified)
|
| 53 |
-
finetune_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs)
|
| 54 |
-
finetune_global_batch_size: int # Global Batch Size (divided across processes)
|
| 55 |
-
finetune_per_device_batch_size: int # Per-Device Batch Size (per-process)
|
| 56 |
-
# => # of accumulation steps is auto-computed
|
| 57 |
-
|
| 58 |
-
finetune_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay)
|
| 59 |
-
finetune_weight_decay: float # Weight Decay for AdamW Optimizer
|
| 60 |
-
finetune_max_grad_norm: float # Max Grad Norm (for global gradient clipping)
|
| 61 |
-
finetune_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay")
|
| 62 |
-
finetune_warmup_ratio: float # Fraction of total steps to warmup
|
| 63 |
-
|
| 64 |
-
finetune_train_strategy: str # Finetune Train Strategy (default: "fsdp-full-shard")
|
| 65 |
-
|
| 66 |
-
# Enable Gradient/Activation Checkpointing (for the LLM Backbone)
|
| 67 |
-
enable_gradient_checkpointing: bool = True
|
| 68 |
-
|
| 69 |
-
# Enable Traditional Mixed Precision Training via Torch Native AMP (`autocast`)
|
| 70 |
-
enable_mixed_precision_training: bool = True # Whether to enable mixed precision training
|
| 71 |
-
reduce_in_full_precision: bool = False # Whether to run gradient reduction in FP32
|
| 72 |
-
|
| 73 |
-
# fmt: on
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
# === LLaVa v1.5 Reproduction - Fully Specified Configurations ===
|
| 77 |
-
@dataclass
|
| 78 |
-
class LLaVa_v15_Reproduction_7B(ModelConfig):
|
| 79 |
-
model_id: str = "reproduction-llava-v15+7b"
|
| 80 |
-
arch_specifier: str = "gelu-mlp"
|
| 81 |
-
|
| 82 |
-
vision_backbone_id: str = "clip-vit-l-336px"
|
| 83 |
-
llm_backbone_id: str = "vicuna-v15-7b"
|
| 84 |
-
|
| 85 |
-
image_resize_strategy: str = "letterbox"
|
| 86 |
-
llm_max_length: int = 2048
|
| 87 |
-
|
| 88 |
-
# Align Stage Optimization Parameters
|
| 89 |
-
align_epochs: int = 1
|
| 90 |
-
align_max_steps: Optional[int] = None
|
| 91 |
-
align_global_batch_size: int = 256
|
| 92 |
-
align_per_device_batch_size: int = 16
|
| 93 |
-
|
| 94 |
-
align_learning_rate: float = 1e-3
|
| 95 |
-
align_weight_decay: float = 0.0
|
| 96 |
-
align_max_grad_norm: float = 1.0
|
| 97 |
-
align_lr_scheduler_type: str = "linear-warmup+cosine-decay"
|
| 98 |
-
align_warmup_ratio: float = 0.03
|
| 99 |
-
|
| 100 |
-
align_train_strategy: str = "fsdp-shard-grad-op"
|
| 101 |
-
|
| 102 |
-
# Finetune Stage Optimization Parameters
|
| 103 |
-
finetune_epochs: int = 1
|
| 104 |
-
finetune_max_steps: Optional[int] = None
|
| 105 |
-
finetune_global_batch_size: int = 128
|
| 106 |
-
finetune_per_device_batch_size: int = 16
|
| 107 |
-
|
| 108 |
-
finetune_learning_rate: float = 2e-5
|
| 109 |
-
finetune_weight_decay: float = 0.1
|
| 110 |
-
finetune_max_grad_norm: float = 1.0
|
| 111 |
-
finetune_lr_scheduler_type: str = "linear-warmup+cosine-decay"
|
| 112 |
-
finetune_warmup_ratio: float = 0.03
|
| 113 |
-
|
| 114 |
-
finetune_train_strategy: str = "fsdp-full-shard"
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
@dataclass
|
| 118 |
-
class LLaVa_v15_Reproduction_13B(LLaVa_v15_Reproduction_7B):
|
| 119 |
-
model_id: str = "reproduction-llava-v15+13b"
|
| 120 |
-
llm_backbone_id: str = "vicuna-v15-13b"
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
# === Section 4.1 :: Optimization Procedure ===
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
# Section 4.1A :: 🚀 --> Necessity of Multi-Stage Training
|
| 127 |
-
@dataclass
|
| 128 |
-
class Exp_7B_One_Stage(LLaVa_v15_Reproduction_7B):
|
| 129 |
-
model_id: str = "one-stage+7b"
|
| 130 |
-
arch_specifier: str = "no-align+gelu-mlp"
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
@dataclass
|
| 134 |
-
class Exp_13B_One_Stage(LLaVa_v15_Reproduction_13B):
|
| 135 |
-
model_id: str = "one-stage+13b"
|
| 136 |
-
arch_specifier: str = "no-align+gelu-mlp"
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# Section 4.1B :: 🛠️ --> Full Finetuning through Visual Backbones
|
| 140 |
-
# =>> Note :: Run with `--stage full-finetune`
|
| 141 |
-
@dataclass
|
| 142 |
-
class Exp_7B_Full_Finetune_Multi_Stage(LLaVa_v15_Reproduction_7B):
|
| 143 |
-
model_id: str = "full-ft-multi-stage+7b"
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
@dataclass
|
| 147 |
-
class Exp_7B_Full_Finetune_One_Stage(Exp_7B_One_Stage):
|
| 148 |
-
model_id: str = "full-ft-one-stage+7b"
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
# === Section 4.2 :: Image Processing and Visual Representations ===
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
# Section 4.2A :: 📸 --> Choosing a Pretrained Representation
|
| 155 |
-
@dataclass
|
| 156 |
-
class Exp_7B_IN1K_ViT_L_p16_224px(Exp_7B_One_Stage):
|
| 157 |
-
model_id: str = "in1k-224px+7b"
|
| 158 |
-
vision_backbone_id: str = "in1k-vit-l"
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
@dataclass
|
| 162 |
-
class Exp_7B_DINOv2_ViT_L_p14_224px(Exp_7B_One_Stage):
|
| 163 |
-
model_id: str = "dinov2-224px+7b"
|
| 164 |
-
vision_backbone_id: str = "dinov2-vit-l"
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
@dataclass
|
| 168 |
-
class Exp_7B_CLIP_ViT_L_p14_224px(Exp_7B_One_Stage):
|
| 169 |
-
model_id: str = "clip-224px+7b"
|
| 170 |
-
vision_backbone_id: str = "clip-vit-l"
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
@dataclass
|
| 174 |
-
class Exp_7B_SigLIP_ViT_SO_p14_224px(Exp_7B_One_Stage):
|
| 175 |
-
model_id: str = "siglip-224px+7b"
|
| 176 |
-
vision_backbone_id: str = "siglip-vit-so400m"
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
# Section 4.2B :: 📐 --> Choosing an Image Preprocessing Strategy
|
| 180 |
-
@dataclass
|
| 181 |
-
class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop(Exp_7B_One_Stage):
|
| 182 |
-
model_id: str = "clip-336px-resize-crop+7b"
|
| 183 |
-
image_resize_strategy: str = "resize-crop"
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
@dataclass
|
| 187 |
-
class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage):
|
| 188 |
-
model_id: str = "clip-336px-resize-naive+7b"
|
| 189 |
-
image_resize_strategy: str = "resize-naive"
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
@dataclass
|
| 193 |
-
class Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox(Exp_7B_One_Stage):
|
| 194 |
-
model_id: str = "siglip-384px-letterbox+7b"
|
| 195 |
-
vision_backbone_id: str = "siglip-vit-so400m-384px"
|
| 196 |
-
image_resize_strategy: str = "letterbox"
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
@dataclass
|
| 200 |
-
class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop(Exp_7B_One_Stage):
|
| 201 |
-
model_id: str = "siglip-384px-resize-crop+7b"
|
| 202 |
-
vision_backbone_id: str = "siglip-vit-so400m-384px"
|
| 203 |
-
image_resize_strategy: str = "resize-crop"
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
@dataclass
|
| 207 |
-
class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive(Exp_7B_One_Stage):
|
| 208 |
-
model_id: str = "siglip-384px-resize-naive+7b"
|
| 209 |
-
vision_backbone_id: str = "siglip-vit-so400m-384px"
|
| 210 |
-
image_resize_strategy: str = "resize-naive"
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
# Section 4.2D :: 🥞 --> Stacking/Ensembling Visual Representations
|
| 214 |
-
@dataclass
|
| 215 |
-
class Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox(Exp_7B_One_Stage):
|
| 216 |
-
model_id: str = "dinoclip-336px-letterbox+7b"
|
| 217 |
-
vision_backbone_id: str = "dinoclip-vit-l-336px"
|
| 218 |
-
image_resize_strategy: str = "letterbox"
|
| 219 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
@dataclass
|
| 223 |
-
class Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage):
|
| 224 |
-
model_id: str = "dinoclip-336px-resize-naive+7b"
|
| 225 |
-
vision_backbone_id: str = "dinoclip-vit-l-336px"
|
| 226 |
-
image_resize_strategy: str = "resize-naive"
|
| 227 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
@dataclass
|
| 231 |
-
class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox(Exp_7B_One_Stage):
|
| 232 |
-
model_id: str = "dinosiglip-384px-letterbox+7b"
|
| 233 |
-
vision_backbone_id: str = "dinosiglip-vit-so-384px"
|
| 234 |
-
image_resize_strategy: str = "letterbox"
|
| 235 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
@dataclass
|
| 239 |
-
class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive(Exp_7B_One_Stage):
|
| 240 |
-
model_id: str = "dinosiglip-384px-resize-naive+7b"
|
| 241 |
-
vision_backbone_id: str = "dinosiglip-vit-so-384px"
|
| 242 |
-
image_resize_strategy: str = "resize-naive"
|
| 243 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
# === Section 4.3 :: Language Models ===
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
# Section 4.3A :: 📝 --> Base vs. Instruct-Tuned (Chat) LLMs
|
| 250 |
-
@dataclass
|
| 251 |
-
class Exp_7B_Llama2(Exp_7B_One_Stage):
|
| 252 |
-
model_id: str = "llama2+7b"
|
| 253 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
@dataclass
|
| 257 |
-
class Exp_13B_Llama2(Exp_13B_One_Stage):
|
| 258 |
-
model_id: str = "llama2+13b"
|
| 259 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~
|
| 263 |
-
@dataclass
|
| 264 |
-
class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
|
| 265 |
-
model_id: str = "llama2-chat+7b"
|
| 266 |
-
llm_backbone_id: str = "llama2-7b-chat"
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
@dataclass
|
| 270 |
-
class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage):
|
| 271 |
-
model_id: str = "llama2-chat+13b"
|
| 272 |
-
llm_backbone_id: str = "llama2-13b-chat"
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
@dataclass
|
| 276 |
-
class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage):
|
| 277 |
-
model_id: str = "mistral-v0.1+7b"
|
| 278 |
-
llm_backbone_id: str = "mistral-v0.1-7b-pure"
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
@dataclass
|
| 282 |
-
class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
|
| 283 |
-
model_id: str = "mistral-instruct-v0.1+7b"
|
| 284 |
-
llm_backbone_id: str = "mistral-v0.1-7b-instruct"
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
@dataclass
|
| 288 |
-
class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage):
|
| 289 |
-
model_id: str = "phi-2+3b"
|
| 290 |
-
llm_backbone_id: str = "phi-2-3b"
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
# Section 4.3B :: ✌️ --> Co-training on Language-only Data
|
| 294 |
-
# =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training)
|
| 295 |
-
@dataclass
|
| 296 |
-
class Exp_7B_Vicuna_No_Cotraining(Exp_7B_One_Stage):
|
| 297 |
-
model_id: str = "vicuna-no-cotraining+7b"
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
@dataclass
|
| 301 |
-
class Exp_7B_Llama2_No_Cotraining(Exp_7B_One_Stage):
|
| 302 |
-
model_id: str = "llama2-no-cotraining+7b"
|
| 303 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
# === Section 4.4 :: Scaling Properties - Train Time & Data ===
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
# Section 4.4A :: ⏰ --> Scaling Train Time
|
| 310 |
-
@dataclass
|
| 311 |
-
class Exp_7B_1p25_Epochs(Exp_7B_One_Stage):
|
| 312 |
-
model_id: str = "train-1.25-epochs+7b"
|
| 313 |
-
finetune_max_steps: int = 6500
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
@dataclass
|
| 317 |
-
class Exp_7B_1p5_Epochs(Exp_7B_One_Stage):
|
| 318 |
-
model_id: str = "train-1.5-epochs+7b"
|
| 319 |
-
finetune_max_steps: int = 7800
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
@dataclass
|
| 323 |
-
class Exp_7B_2_Epochs(Exp_7B_One_Stage):
|
| 324 |
-
model_id: str = "train-2-epochs+7b"
|
| 325 |
-
finetune_epochs: int = 2
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
@dataclass
|
| 329 |
-
class Exp_7B_3_Epochs(Exp_7B_One_Stage):
|
| 330 |
-
model_id: str = "train-3-epochs+7b"
|
| 331 |
-
finetune_epochs: int = 3
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
# Section 4.4B :: 📚 --> Scaling Data
|
| 335 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v"`
|
| 336 |
-
@dataclass
|
| 337 |
-
class Exp_7B_LLaVa_LVIS4V(Exp_7B_One_Stage):
|
| 338 |
-
model_id: str = "llava-lvis4v+7b"
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
# =>> Note :: Run with `--dataset.type "llava-lrv"`
|
| 342 |
-
@dataclass
|
| 343 |
-
class Exp_7B_LLaVa_LRV(Exp_7B_One_Stage):
|
| 344 |
-
model_id: str = "llava-lrv+7b"
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 348 |
-
@dataclass
|
| 349 |
-
class Exp_7B_LLaVa_LVIS4V_LRV(Exp_7B_One_Stage):
|
| 350 |
-
model_id: str = "llava-lvis4v-lrv+7b"
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
# === Section 5 :: Prisms ===
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
# Prism-CLIP
|
| 357 |
-
@dataclass
|
| 358 |
-
class Prism_7B_CLIP_Controlled(Exp_7B_One_Stage):
|
| 359 |
-
model_id: str = "prism-clip-controlled+7b"
|
| 360 |
-
vision_backbone_id: str = "clip-vit-l-336px"
|
| 361 |
-
image_resize_strategy: str = "resize-naive"
|
| 362 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
@dataclass
|
| 366 |
-
class Prism_13B_CLIP_Controlled(Exp_13B_One_Stage):
|
| 367 |
-
model_id: str = "prism-clip-controlled+13b"
|
| 368 |
-
vision_backbone_id: str = "clip-vit-l-336px"
|
| 369 |
-
image_resize_strategy: str = "resize-naive"
|
| 370 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 374 |
-
@dataclass
|
| 375 |
-
class Prism_7B_CLIP(Exp_7B_One_Stage):
|
| 376 |
-
model_id: str = "prism-clip+7b"
|
| 377 |
-
vision_backbone_id: str = "clip-vit-l-336px"
|
| 378 |
-
image_resize_strategy: str = "resize-naive"
|
| 379 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 380 |
-
finetune_epochs: int = 2
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 384 |
-
@dataclass
|
| 385 |
-
class Prism_13B_CLIP(Exp_13B_One_Stage):
|
| 386 |
-
model_id: str = "prism-clip+13b"
|
| 387 |
-
vision_backbone_id: str = "clip-vit-l-336px"
|
| 388 |
-
image_resize_strategy: str = "resize-naive"
|
| 389 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 390 |
-
finetune_epochs: int = 2
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
# Prism-SigLIP
|
| 394 |
-
@dataclass
|
| 395 |
-
class Prism_7B_SigLIP_Controlled(Exp_7B_One_Stage):
|
| 396 |
-
model_id: str = "prism-siglip-controlled+7b"
|
| 397 |
-
vision_backbone_id: str = "siglip-vit-so400m-384px"
|
| 398 |
-
image_resize_strategy: str = "resize-naive"
|
| 399 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
@dataclass
|
| 403 |
-
class Prism_13B_SigLIP_Controlled(Exp_13B_One_Stage):
|
| 404 |
-
model_id: str = "prism-siglip-controlled+13b"
|
| 405 |
-
vision_backbone_id: str = "siglip-vit-so400m-384px"
|
| 406 |
-
image_resize_strategy: str = "resize-naive"
|
| 407 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 411 |
-
@dataclass
|
| 412 |
-
class Prism_7B_SigLIP(Exp_7B_One_Stage):
|
| 413 |
-
model_id: str = "prism-siglip+7b"
|
| 414 |
-
vision_backbone_id: str = "siglip-vit-so400m-384px"
|
| 415 |
-
image_resize_strategy: str = "resize-naive"
|
| 416 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 417 |
-
finetune_epochs: int = 2
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 421 |
-
@dataclass
|
| 422 |
-
class Prism_13B_SigLIP(Exp_13B_One_Stage):
|
| 423 |
-
model_id: str = "prism-siglip+13b"
|
| 424 |
-
vision_backbone_id: str = "clip-vit-l-336px"
|
| 425 |
-
image_resize_strategy: str = "resize-naive"
|
| 426 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 427 |
-
finetune_epochs: int = 2
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
# Prism-DINOSigLIP
|
| 431 |
-
@dataclass
|
| 432 |
-
class Prism_7B_DINOSigLIP_Controlled(Exp_7B_One_Stage):
|
| 433 |
-
model_id: str = "prism-dinosiglip-controlled+7b"
|
| 434 |
-
vision_backbone_id: str = "dinosiglip-vit-so-384px"
|
| 435 |
-
image_resize_strategy: str = "resize-naive"
|
| 436 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 437 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
@dataclass
|
| 441 |
-
class Prism_13B_DINOSigLIP_Controlled(Exp_13B_One_Stage):
|
| 442 |
-
model_id: str = "prism-dinosiglip-controlled+13b"
|
| 443 |
-
vision_backbone_id: str = "dinosiglip-vit-so-384px"
|
| 444 |
-
image_resize_strategy: str = "resize-naive"
|
| 445 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 446 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 450 |
-
@dataclass
|
| 451 |
-
class Prism_7B_DINOSigLIP(Exp_7B_One_Stage):
|
| 452 |
-
model_id: str = "prism-dinosiglip+7b"
|
| 453 |
-
vision_backbone_id: str = "dinosiglip-vit-so-384px"
|
| 454 |
-
image_resize_strategy: str = "resize-naive"
|
| 455 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 456 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 457 |
-
finetune_epochs: int = 2
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 461 |
-
@dataclass
|
| 462 |
-
class Prism_13B_DINOSigLIP(Exp_13B_One_Stage):
|
| 463 |
-
model_id: str = "prism-dinosiglip+13b"
|
| 464 |
-
vision_backbone_id: str = "dinosiglip-vit-so-384px"
|
| 465 |
-
image_resize_strategy: str = "resize-naive"
|
| 466 |
-
llm_backbone_id: str = "llama2-13b-pure"
|
| 467 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 468 |
-
finetune_epochs: int = 2
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
# [Inference-Optimized] 224px Prisms
|
| 472 |
-
@dataclass
|
| 473 |
-
class Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive(Exp_7B_One_Stage):
|
| 474 |
-
model_id: str = "dinosiglip-224px-resize-naive+7b"
|
| 475 |
-
vision_backbone_id: str = "dinosiglip-vit-so-224px"
|
| 476 |
-
image_resize_strategy: str = "resize-naive"
|
| 477 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
@dataclass
|
| 481 |
-
class Prism_7B_DINOSigLIP_224px_Controlled(Exp_7B_One_Stage):
|
| 482 |
-
model_id: str = "prism-dinosiglip-224px-controlled+7b"
|
| 483 |
-
vision_backbone_id: str = "dinosiglip-vit-so-224px"
|
| 484 |
-
image_resize_strategy: str = "resize-naive"
|
| 485 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 486 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"`
|
| 490 |
-
@dataclass
|
| 491 |
-
class Prism_7B_DINOSigLIP_224px(Exp_7B_One_Stage):
|
| 492 |
-
model_id: str = "prism-dinosiglip-224px+7b"
|
| 493 |
-
vision_backbone_id: str = "dinosiglip-vit-so-224px"
|
| 494 |
-
image_resize_strategy: str = "resize-naive"
|
| 495 |
-
llm_backbone_id: str = "llama2-7b-pure"
|
| 496 |
-
arch_specifier: str = "no-align+fused-gelu-mlp"
|
| 497 |
-
finetune_epochs: int = 2
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
# === Define a Model Registry Enum for Reference & Validation ===
|
| 501 |
-
@unique
|
| 502 |
-
class ModelRegistry(Enum):
|
| 503 |
-
# === LLaVa v1.5 Base Reproductions ===
|
| 504 |
-
REPRODUCTION_7B = LLaVa_v15_Reproduction_7B
|
| 505 |
-
REPRODUCTION_13B = LLaVa_v15_Reproduction_13B
|
| 506 |
-
|
| 507 |
-
# === Section 4.1 :: Optimization Procedure ===
|
| 508 |
-
EXP_ONE_STAGE_7B = Exp_7B_One_Stage
|
| 509 |
-
EXP_ONE_STAGE_13B = Exp_13B_One_Stage
|
| 510 |
-
|
| 511 |
-
EXP_FULL_FT_MULTI_STAGE = Exp_7B_Full_Finetune_Multi_Stage
|
| 512 |
-
EXP_FULL_FT_ONE_STAGE = Exp_7B_Full_Finetune_One_Stage
|
| 513 |
-
|
| 514 |
-
# === Section 4.2 :: Image Processing and Visual Representations ===
|
| 515 |
-
EXP_IN1K_224PX = Exp_7B_IN1K_ViT_L_p16_224px
|
| 516 |
-
EXP_DINOV2_224PX = Exp_7B_DINOv2_ViT_L_p14_224px
|
| 517 |
-
EXP_CLIP_224PX = Exp_7B_CLIP_ViT_L_p14_224px
|
| 518 |
-
EXP_SIGLIP_224PX = Exp_7B_SigLIP_ViT_SO_p14_224px
|
| 519 |
-
|
| 520 |
-
EXP_CLIP_336PX_RESIZE_CROP = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop
|
| 521 |
-
EXP_CLIP_336PX_RESIZE_NAIVE = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive
|
| 522 |
-
EXP_SIGLIP_384PX_LETTERBOX = Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox
|
| 523 |
-
EXP_SIGLIP_384PX_RESIZE_CROP = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop
|
| 524 |
-
EXP_SIGLIP_384PX_RESIZE_NAIVE = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive
|
| 525 |
-
|
| 526 |
-
EXP_DINOCLIP_336PX_LETTERBOX = Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox
|
| 527 |
-
EXP_DINOCLIP_336PX_RESIZE_NAIVE = Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive
|
| 528 |
-
EXP_DINOSIGLIP_384PX_LETTERBOX = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox
|
| 529 |
-
EXP_DINOSIGLIP_384PX_RESIZE_NAIVE = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive
|
| 530 |
-
|
| 531 |
-
# === Section 4.3 :: Language Models ===
|
| 532 |
-
EXP_LLAMA2_7B = Exp_7B_Llama2
|
| 533 |
-
EXP_LLAMA2_13B = Exp_13B_Llama2
|
| 534 |
-
|
| 535 |
-
# ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
|
| 536 |
-
EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat
|
| 537 |
-
EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
|
| 538 |
-
EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
|
| 539 |
-
EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
|
| 540 |
-
EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2
|
| 541 |
-
|
| 542 |
-
# Cotraining w/ Unimodal Data
|
| 543 |
-
EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
|
| 544 |
-
EXP_LLAMA2_NO_COTRAINING_7B = Exp_7B_Llama2_No_Cotraining
|
| 545 |
-
|
| 546 |
-
# === Section 4.4 :: Scaling Properties - Train Time & Data ===
|
| 547 |
-
EXP_1P25_EPOCHS = Exp_7B_1p25_Epochs
|
| 548 |
-
EXP_1P5_EPOCHS = Exp_7B_1p5_Epochs
|
| 549 |
-
EXP_2_EPOCHS = Exp_7B_2_Epochs
|
| 550 |
-
EXP_3_EPOCHS = Exp_7B_3_Epochs
|
| 551 |
-
|
| 552 |
-
EXP_LLAVA_LVIS4V = Exp_7B_LLaVa_LVIS4V
|
| 553 |
-
EXP_LLAVA_LRV = Exp_7B_LLaVa_LRV
|
| 554 |
-
EXP_LLAVA_LVIS4V_LRV = Exp_7B_LLaVa_LVIS4V_LRV
|
| 555 |
-
|
| 556 |
-
# === Section 5 :: Prisms ===
|
| 557 |
-
PRISM_CLIP_CONTROLLED_7B = Prism_7B_CLIP_Controlled
|
| 558 |
-
PRISM_CLIP_CONTROLLED_13B = Prism_13B_CLIP_Controlled
|
| 559 |
-
PRISM_CLIP_7B = Prism_7B_CLIP
|
| 560 |
-
PRISM_CLIP_13B = Prism_13B_CLIP
|
| 561 |
-
|
| 562 |
-
PRISM_SIGLIP_CONTROLLED_7B = Prism_7B_SigLIP_Controlled
|
| 563 |
-
PRISM_SIGLIP_CONTROLLED_13B = Prism_13B_SigLIP_Controlled
|
| 564 |
-
PRISM_SIGLIP_7B = Prism_7B_SigLIP
|
| 565 |
-
PRISM_SIGLIP_13B = Prism_13B_SigLIP
|
| 566 |
-
|
| 567 |
-
PRISM_DINOSIGLIP_CONTROLLED_7B = Prism_7B_DINOSigLIP_Controlled
|
| 568 |
-
PRISM_DINOSIGLIP_CONTROLLED_13B = Prism_13B_DINOSigLIP_Controlled
|
| 569 |
-
PRISM_DINOSIGLIP_7B = Prism_7B_DINOSigLIP
|
| 570 |
-
PRISM_DINOSIGLIP_13B = Prism_13B_DINOSigLIP
|
| 571 |
-
|
| 572 |
-
# === Inference Optimized :: 224px Prisms ===
|
| 573 |
-
OPT_DINOSIGLIP_224PX_RESIZE_NAIVE = Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive
|
| 574 |
-
PRISM_DINOSIGLIP_224PX_CONTROLLED_7B = Prism_7B_DINOSigLIP_224px_Controlled
|
| 575 |
-
PRISM_DINOSIGLIP_224PX_7B = Prism_7B_DINOSigLIP_224px
|
| 576 |
-
|
| 577 |
-
@property
|
| 578 |
-
def model_id(self) -> str:
|
| 579 |
-
return self.value.model_id
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
# Register Models in Choice Registry
|
| 583 |
-
for model_variant in ModelRegistry:
|
| 584 |
-
ModelConfig.register_subclass(model_variant.model_id, model_variant.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/conf/vla.py
DELETED
|
@@ -1,235 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
vla.py
|
| 3 |
-
|
| 4 |
-
Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and
|
| 5 |
-
model configuration thereof. A given VLA model (`policy`) configures the following attributes:
|
| 6 |
-
- Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.)
|
| 7 |
-
- Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`)
|
| 8 |
-
- VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
|
| 9 |
-
- Training / Optimization Hyperparameters
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from enum import Enum, unique
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Optional, Union
|
| 16 |
-
|
| 17 |
-
from draccus import ChoiceRegistry
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
@dataclass
|
| 21 |
-
class VLAConfig(ChoiceRegistry):
|
| 22 |
-
# fmt: off
|
| 23 |
-
vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant
|
| 24 |
-
base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`)
|
| 25 |
-
freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining)
|
| 26 |
-
freeze_llm_backbone: bool # Freeze LLM Backbone parameters
|
| 27 |
-
unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen)
|
| 28 |
-
|
| 29 |
-
# Data Mixture Parameters
|
| 30 |
-
data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`)
|
| 31 |
-
shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE)
|
| 32 |
-
|
| 33 |
-
# Optimization Parameters
|
| 34 |
-
epochs: int # Epochs to Run (in case `max_steps` is not specified)
|
| 35 |
-
max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`)
|
| 36 |
-
|
| 37 |
-
expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware
|
| 38 |
-
global_batch_size: int # Global Batch Size (divided across processes / world size)
|
| 39 |
-
per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU)
|
| 40 |
-
# =>> # of accumulation steps is auto-computed
|
| 41 |
-
|
| 42 |
-
learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
|
| 43 |
-
weight_decay: float # Weight Decay for AdamW Optimizer
|
| 44 |
-
max_grad_norm: float # Max Grad Norm (for global gradient clipping)
|
| 45 |
-
lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
|
| 46 |
-
warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers)
|
| 47 |
-
|
| 48 |
-
train_strategy: str # Train Strategy (default "fsdp-full-shard")
|
| 49 |
-
|
| 50 |
-
# Enable Gradient/Activation Checkpointing (for the LLM Backbone)
|
| 51 |
-
enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training
|
| 52 |
-
|
| 53 |
-
# Mixed Precision Training via Torch Native AMP (`autocast`)
|
| 54 |
-
enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision
|
| 55 |
-
reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision
|
| 56 |
-
|
| 57 |
-
# fmt: on
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# === OpenVLA Training Configurations ===
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
# = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
|
| 64 |
-
@dataclass
|
| 65 |
-
class Exp_SigLIP_224px_Bridge(VLAConfig):
|
| 66 |
-
vla_id: str = "siglip-224px+mx-bridge"
|
| 67 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 68 |
-
|
| 69 |
-
freeze_vision_backbone: bool = False
|
| 70 |
-
freeze_llm_backbone: bool = False
|
| 71 |
-
unfreeze_last_llm_layer: bool = False
|
| 72 |
-
|
| 73 |
-
# Data Mixture Parameters
|
| 74 |
-
data_mix: str = "bridge"
|
| 75 |
-
shuffle_buffer_size: int = 256_000
|
| 76 |
-
|
| 77 |
-
# Optimization Parameters
|
| 78 |
-
epochs: int = 1000
|
| 79 |
-
max_steps: Optional[int] = None
|
| 80 |
-
|
| 81 |
-
expected_world_size: int = 8
|
| 82 |
-
global_batch_size: int = 256
|
| 83 |
-
per_device_batch_size: int = 32
|
| 84 |
-
|
| 85 |
-
learning_rate: float = 2e-5
|
| 86 |
-
weight_decay: float = 0.0
|
| 87 |
-
max_grad_norm: float = 1.0
|
| 88 |
-
lr_scheduler_type: str = "constant"
|
| 89 |
-
warmup_ratio: float = 0.0
|
| 90 |
-
|
| 91 |
-
train_strategy: str = "fsdp-full-shard"
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
|
| 95 |
-
@dataclass
|
| 96 |
-
class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
|
| 97 |
-
vla_id: str = "siglip-224px-icy+mx-bridge"
|
| 98 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 99 |
-
freeze_vision_backbone: bool = True
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
# = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
|
| 103 |
-
@dataclass
|
| 104 |
-
class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
|
| 105 |
-
vla_id: str = "prism-dinosiglip-224px+mx-bridge"
|
| 106 |
-
base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
|
| 107 |
-
|
| 108 |
-
data_mix: str = "bridge"
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# = [64 GPU] SigLIP 224px + OXE Magic Soup =
|
| 112 |
-
@dataclass
|
| 113 |
-
class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
|
| 114 |
-
vla_id: str = "siglip-224px+mx-oxe-magic-soup"
|
| 115 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 116 |
-
|
| 117 |
-
data_mix: str = "oxe_magic_soup"
|
| 118 |
-
|
| 119 |
-
expected_world_size: int = 64
|
| 120 |
-
global_batch_size: int = 2048
|
| 121 |
-
per_device_batch_size: int = 32
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
|
| 125 |
-
@dataclass
|
| 126 |
-
class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
|
| 127 |
-
vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
|
| 128 |
-
base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
|
| 129 |
-
|
| 130 |
-
# Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling!
|
| 131 |
-
# data_mix: str = "oxe_magic_soup_plus"
|
| 132 |
-
data_mix: str = "oxe_magic_soup_plus_minus"
|
| 133 |
-
|
| 134 |
-
expected_world_size: int = 64
|
| 135 |
-
global_batch_size: int = 2048
|
| 136 |
-
per_device_batch_size: int = 32
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# === OpenVLA Fine-tuning Configurations ===
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
# = [8 GPU] SigLIP 224px + T-DROID =
|
| 143 |
-
@dataclass
|
| 144 |
-
class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
|
| 145 |
-
vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
|
| 146 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 147 |
-
|
| 148 |
-
data_mix: str = "tdroid_carrot_in_bowl"
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
@dataclass
|
| 152 |
-
class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
|
| 153 |
-
vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
|
| 154 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 155 |
-
|
| 156 |
-
data_mix: str = "tdroid_pour_corn_in_pot"
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
# = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
|
| 160 |
-
@dataclass
|
| 161 |
-
class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
|
| 162 |
-
vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
|
| 163 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 164 |
-
freeze_vision_backbone: bool = True
|
| 165 |
-
freeze_llm_backbone: bool = False
|
| 166 |
-
|
| 167 |
-
data_mix: str = "tdroid_carrot_in_bowl"
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
@dataclass
|
| 171 |
-
class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
|
| 172 |
-
vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
|
| 173 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 174 |
-
freeze_vision_backbone: bool = True
|
| 175 |
-
freeze_llm_backbone: bool = True
|
| 176 |
-
unfreeze_last_llm_layer: bool = True
|
| 177 |
-
|
| 178 |
-
data_mix: str = "tdroid_carrot_in_bowl"
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
@dataclass
|
| 182 |
-
class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
|
| 183 |
-
vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
|
| 184 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 185 |
-
freeze_vision_backbone: bool = False
|
| 186 |
-
freeze_llm_backbone: bool = True
|
| 187 |
-
unfreeze_last_llm_layer: bool = True
|
| 188 |
-
|
| 189 |
-
data_mix: str = "tdroid_carrot_in_bowl"
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
# === [8 GPU] SigLIP 224px + FrankaWipe ===
|
| 193 |
-
@dataclass
|
| 194 |
-
class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
|
| 195 |
-
vla_id: str = "siglip-224px+mx-droid_wipe"
|
| 196 |
-
base_vlm: Union[str, Path] = "siglip-224px+7b"
|
| 197 |
-
|
| 198 |
-
data_mix: str = "droid_wipe"
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
# === Define a VLA Registry Enum for Reference & Validation ===
|
| 202 |
-
@unique
|
| 203 |
-
class VLARegistry(Enum):
|
| 204 |
-
# Sanity Check Configurations =>> BridgeV2
|
| 205 |
-
SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
|
| 206 |
-
DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge
|
| 207 |
-
|
| 208 |
-
# SigLIP Frozen Backbone Experiment
|
| 209 |
-
FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge
|
| 210 |
-
|
| 211 |
-
# [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
|
| 212 |
-
SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup
|
| 213 |
-
|
| 214 |
-
# [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
|
| 215 |
-
DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus
|
| 216 |
-
|
| 217 |
-
# === TDROID Fine-tuning Configs ===
|
| 218 |
-
SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
|
| 219 |
-
SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot
|
| 220 |
-
|
| 221 |
-
SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
|
| 222 |
-
SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
|
| 223 |
-
SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl
|
| 224 |
-
|
| 225 |
-
# === DROID Fine-tuning Configs ===
|
| 226 |
-
SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe
|
| 227 |
-
|
| 228 |
-
@property
|
| 229 |
-
def vla_id(self) -> str:
|
| 230 |
-
return self.value.vla_id
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
# Register VLAs in Choice Registry
|
| 234 |
-
for vla_variant in VLARegistry:
|
| 235 |
-
VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/extern/__init__.py
DELETED
|
File without changes
|
capvector-oft/prismatic/extern/hf/__init__.py
DELETED
|
File without changes
|
capvector-oft/prismatic/extern/hf/configuration_prismatic.py
DELETED
|
@@ -1,140 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
configuration_prismatic.py
|
| 3 |
-
|
| 4 |
-
HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
|
| 5 |
-
Default configuration specifies `siglip-224px+7b`.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from typing import Any, Dict, List, Optional
|
| 9 |
-
|
| 10 |
-
from transformers import PretrainedConfig
|
| 11 |
-
from transformers.models.auto import CONFIG_MAPPING
|
| 12 |
-
|
| 13 |
-
# === Utilities for Mapping Prismatic names to HF names ===
|
| 14 |
-
# fmt: off
|
| 15 |
-
VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
|
| 16 |
-
"clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
|
| 17 |
-
|
| 18 |
-
"clip-vit-l-336px": [336],
|
| 19 |
-
"siglip-vit-so400m-384px": [384],
|
| 20 |
-
|
| 21 |
-
"dinoclip-vit-l-336px": [336, 336],
|
| 22 |
-
"dinosiglip-vit-so-224px": [224, 224],
|
| 23 |
-
"dinosiglip-vit-so-384px": [384, 384],
|
| 24 |
-
}
|
| 25 |
-
VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
|
| 26 |
-
"clip-vit-l": ["vit_large_patch14_clip_224.openai"],
|
| 27 |
-
"clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
|
| 28 |
-
|
| 29 |
-
"dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
|
| 30 |
-
"in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
|
| 31 |
-
|
| 32 |
-
"siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
|
| 33 |
-
"siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
|
| 34 |
-
|
| 35 |
-
"dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
|
| 36 |
-
"dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
|
| 37 |
-
"dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
|
| 38 |
-
}
|
| 39 |
-
TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
|
| 40 |
-
"clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
|
| 41 |
-
"dinov2-vit-l": [None], "in1k-vit-l": [None],
|
| 42 |
-
"siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
|
| 43 |
-
"dinoclip-vit-l-336px": [None, "quick_gelu"],
|
| 44 |
-
"dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
LLM_BACKBONE_TO_HF_PATH = {
|
| 48 |
-
"llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
|
| 49 |
-
"llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
|
| 50 |
-
|
| 51 |
-
"vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
|
| 52 |
-
|
| 53 |
-
"mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
|
| 54 |
-
"mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
|
| 55 |
-
|
| 56 |
-
"phi-2-3b": "microsoft/phi-2",
|
| 57 |
-
}
|
| 58 |
-
LLM_BACKBONE_TO_HF_METACLASS = {
|
| 59 |
-
"llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
|
| 60 |
-
"vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
|
| 61 |
-
|
| 62 |
-
"mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
|
| 63 |
-
|
| 64 |
-
"phi-2-3b": "phi",
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
|
| 68 |
-
VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
|
| 69 |
-
# fmt: on
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class PrismaticConfig(PretrainedConfig):
|
| 73 |
-
model_type: str = "prismatic"
|
| 74 |
-
is_composition: bool = False
|
| 75 |
-
|
| 76 |
-
def __init__(
|
| 77 |
-
self,
|
| 78 |
-
vision_backbone_id: str = "siglip-vit-so400m",
|
| 79 |
-
llm_backbone_id: str = "vicuna-v15-7b",
|
| 80 |
-
arch_specifier: str = "no-align+gelu-mlp",
|
| 81 |
-
use_fused_vision_backbone: Optional[bool] = None,
|
| 82 |
-
image_resize_strategy: str = "letterbox",
|
| 83 |
-
text_config: Optional[Dict[str, Any]] = None,
|
| 84 |
-
llm_max_length: int = 2048,
|
| 85 |
-
pad_token_id: int = 32000,
|
| 86 |
-
pad_to_multiple_of: int = 64,
|
| 87 |
-
output_projector_states: bool = False,
|
| 88 |
-
**kwargs: str,
|
| 89 |
-
) -> None:
|
| 90 |
-
if vision_backbone_id not in VALID_VISION_BACKBONES:
|
| 91 |
-
raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
|
| 92 |
-
|
| 93 |
-
if llm_backbone_id not in VALID_LLM_BACKBONES:
|
| 94 |
-
raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
|
| 95 |
-
|
| 96 |
-
# Set Prismatic Configuration Fields
|
| 97 |
-
self.vision_backbone_id = vision_backbone_id
|
| 98 |
-
self.llm_backbone_id = llm_backbone_id
|
| 99 |
-
self.arch_specifier = arch_specifier
|
| 100 |
-
self.output_projector_states = output_projector_states
|
| 101 |
-
|
| 102 |
-
# [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
|
| 103 |
-
self.use_fused_vision_backbone = (
|
| 104 |
-
use_fused_vision_backbone
|
| 105 |
-
if use_fused_vision_backbone is not None
|
| 106 |
-
else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
|
| 110 |
-
self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
|
| 111 |
-
self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
|
| 112 |
-
self.image_resize_strategy = image_resize_strategy
|
| 113 |
-
|
| 114 |
-
self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
|
| 115 |
-
self.llm_max_length = llm_max_length
|
| 116 |
-
self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
|
| 117 |
-
|
| 118 |
-
# [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
|
| 119 |
-
self.text_config = (
|
| 120 |
-
CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
|
| 121 |
-
if text_config is not None
|
| 122 |
-
else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
# Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
|
| 126 |
-
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
class OpenVLAConfig(PrismaticConfig):
|
| 130 |
-
model_type: str = "openvla"
|
| 131 |
-
|
| 132 |
-
def __init__(
|
| 133 |
-
self,
|
| 134 |
-
norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
|
| 135 |
-
n_action_bins: int = 256,
|
| 136 |
-
**kwargs: str,
|
| 137 |
-
) -> None:
|
| 138 |
-
self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
|
| 139 |
-
|
| 140 |
-
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/extern/hf/modeling_prismatic.py
DELETED
|
@@ -1,1085 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
modeling_prismatic.py
|
| 3 |
-
|
| 4 |
-
Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
|
| 5 |
-
Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
|
| 6 |
-
but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import logging
|
| 10 |
-
from dataclasses import dataclass
|
| 11 |
-
from functools import partial
|
| 12 |
-
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
| 13 |
-
|
| 14 |
-
import numpy as np
|
| 15 |
-
import timm
|
| 16 |
-
import tokenizers
|
| 17 |
-
import torch
|
| 18 |
-
import torch.nn as nn
|
| 19 |
-
import transformers
|
| 20 |
-
from timm.models.vision_transformer import LayerScale
|
| 21 |
-
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
| 22 |
-
from transformers.modeling_outputs import ModelOutput
|
| 23 |
-
|
| 24 |
-
from prismatic.training.train_utils import (
|
| 25 |
-
get_current_action_mask,
|
| 26 |
-
get_next_actions_mask,
|
| 27 |
-
)
|
| 28 |
-
from prismatic.vla.constants import (
|
| 29 |
-
ACTION_DIM,
|
| 30 |
-
ACTION_PROPRIO_NORMALIZATION_TYPE,
|
| 31 |
-
ACTION_TOKEN_BEGIN_IDX,
|
| 32 |
-
IGNORE_INDEX,
|
| 33 |
-
NUM_ACTIONS_CHUNK,
|
| 34 |
-
STOP_INDEX,
|
| 35 |
-
NormalizationType,
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
|
| 39 |
-
|
| 40 |
-
# Set up logger
|
| 41 |
-
logger = logging.getLogger(__name__)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
# === Utility Functions for Monkey-Patching ===
|
| 45 |
-
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
| 46 |
-
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
| 47 |
-
result = fn(*args, **kwargs)
|
| 48 |
-
return result[0] if isinstance(result, tuple) else result
|
| 49 |
-
|
| 50 |
-
return wrapper
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
|
| 54 |
-
# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
|
| 55 |
-
# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
|
| 56 |
-
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 57 |
-
return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def ls_apply_patch(ls_module: LayerScale):
|
| 61 |
-
ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
|
| 62 |
-
ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
|
| 63 |
-
del ls_module.gamma
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
|
| 67 |
-
class PrismaticVisionBackbone(nn.Module):
|
| 68 |
-
"""
|
| 69 |
-
Vision backbone for Prismatic models that handles image feature extraction.
|
| 70 |
-
|
| 71 |
-
Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
|
| 72 |
-
For fused backbones, features from both models are concatenated along the feature dimension.
|
| 73 |
-
"""
|
| 74 |
-
|
| 75 |
-
def __init__(
|
| 76 |
-
self,
|
| 77 |
-
use_fused_vision_backbone: bool,
|
| 78 |
-
image_sizes: List[int],
|
| 79 |
-
timm_model_ids: List[str],
|
| 80 |
-
timm_override_act_layers: List[Optional[str]],
|
| 81 |
-
) -> None:
|
| 82 |
-
"""
|
| 83 |
-
Initialize the vision backbone.
|
| 84 |
-
|
| 85 |
-
Args:
|
| 86 |
-
use_fused_vision_backbone: Whether to use two backbones and fuse their features
|
| 87 |
-
image_sizes: List of image sizes for each backbone
|
| 88 |
-
timm_model_ids: List of TIMM model IDs to use for each backbone
|
| 89 |
-
timm_override_act_layers: List of activation layer overrides for each backbone
|
| 90 |
-
"""
|
| 91 |
-
super().__init__()
|
| 92 |
-
self.use_fused_vision_backbone = use_fused_vision_backbone
|
| 93 |
-
self.num_images_in_input = 1 # Default value, can be overridden later
|
| 94 |
-
|
| 95 |
-
# Validate number of (fused) vision backbones
|
| 96 |
-
if len(timm_model_ids) > 2:
|
| 97 |
-
raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
|
| 98 |
-
|
| 99 |
-
# Create primary featurizer
|
| 100 |
-
self.featurizer = self._create_featurizer(
|
| 101 |
-
model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
|
| 102 |
-
)
|
| 103 |
-
self.embed_dim = self.featurizer.embed_dim
|
| 104 |
-
|
| 105 |
-
# Create secondary featurizer if using fused backbone
|
| 106 |
-
if self.use_fused_vision_backbone:
|
| 107 |
-
self.fused_featurizer = self._create_featurizer(
|
| 108 |
-
model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
|
| 109 |
-
)
|
| 110 |
-
self.embed_dim += self.fused_featurizer.embed_dim
|
| 111 |
-
|
| 112 |
-
# Patch LayerScale modules for HF compatibility
|
| 113 |
-
self._patch_layer_scales()
|
| 114 |
-
|
| 115 |
-
def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
|
| 116 |
-
"""
|
| 117 |
-
Create a TIMM-based featurizer model with appropriate configurations.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
model_id: The TIMM model ID to load
|
| 121 |
-
img_size: Input image size for the model
|
| 122 |
-
act_layer: Override for the activation layer type
|
| 123 |
-
|
| 124 |
-
Returns:
|
| 125 |
-
A configured featurizer model
|
| 126 |
-
"""
|
| 127 |
-
featurizer = timm.create_model(
|
| 128 |
-
model_id,
|
| 129 |
-
pretrained=False,
|
| 130 |
-
num_classes=0,
|
| 131 |
-
img_size=img_size,
|
| 132 |
-
act_layer=act_layer,
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
# Monkey-patch the forward function to extract the second-to-last layer features
|
| 136 |
-
num_blocks = len(featurizer.blocks)
|
| 137 |
-
featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
|
| 138 |
-
|
| 139 |
-
return featurizer
|
| 140 |
-
|
| 141 |
-
def _patch_layer_scales(self) -> None:
|
| 142 |
-
"""
|
| 143 |
-
Patch all LayerScale modules to be compatible with HF's parameter naming.
|
| 144 |
-
|
| 145 |
-
HF Transformers overwrites parameters with names containing 'gamma',
|
| 146 |
-
so we need to rename and modify the forward method.
|
| 147 |
-
"""
|
| 148 |
-
# Patch primary featurizer
|
| 149 |
-
for module in self.featurizer.modules():
|
| 150 |
-
if isinstance(module, LayerScale):
|
| 151 |
-
ls_apply_patch(module)
|
| 152 |
-
|
| 153 |
-
# Patch secondary featurizer if it exists
|
| 154 |
-
if self.use_fused_vision_backbone:
|
| 155 |
-
for module in self.fused_featurizer.modules():
|
| 156 |
-
if isinstance(module, LayerScale):
|
| 157 |
-
ls_apply_patch(module)
|
| 158 |
-
|
| 159 |
-
def get_num_patches(self) -> int:
|
| 160 |
-
"""
|
| 161 |
-
Returns the number of vision patches output by the vision backbone.
|
| 162 |
-
|
| 163 |
-
Returns:
|
| 164 |
-
Number of patches per image
|
| 165 |
-
"""
|
| 166 |
-
return self.featurizer.patch_embed.num_patches
|
| 167 |
-
|
| 168 |
-
def get_num_images_in_input(self) -> int:
|
| 169 |
-
"""
|
| 170 |
-
Returns the number of input images for the vision backbone.
|
| 171 |
-
|
| 172 |
-
Returns:
|
| 173 |
-
Number of images expected in the input
|
| 174 |
-
"""
|
| 175 |
-
return self.num_images_in_input
|
| 176 |
-
|
| 177 |
-
def set_num_images_in_input(self, num_images_in_input: int) -> None:
|
| 178 |
-
"""
|
| 179 |
-
Sets the number of input images for the vision backbone.
|
| 180 |
-
|
| 181 |
-
Args:
|
| 182 |
-
num_images_in_input: Number of images to expect in the input
|
| 183 |
-
"""
|
| 184 |
-
self.num_images_in_input = num_images_in_input
|
| 185 |
-
|
| 186 |
-
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 187 |
-
"""
|
| 188 |
-
Implements the forward pass for the vision backbone.
|
| 189 |
-
|
| 190 |
-
If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
|
| 191 |
-
(otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
|
| 192 |
-
|
| 193 |
-
Args:
|
| 194 |
-
pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
|
| 195 |
-
"""
|
| 196 |
-
if self.num_images_in_input == 1:
|
| 197 |
-
if not self.use_fused_vision_backbone:
|
| 198 |
-
return self.featurizer(pixel_values)
|
| 199 |
-
|
| 200 |
-
# Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
|
| 201 |
-
img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
|
| 202 |
-
patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
|
| 203 |
-
|
| 204 |
-
return torch.cat([patches, patches_fused], dim=2)
|
| 205 |
-
|
| 206 |
-
else:
|
| 207 |
-
assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
|
| 208 |
-
|
| 209 |
-
# Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
|
| 210 |
-
images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
|
| 211 |
-
|
| 212 |
-
# Process each image and collect patches
|
| 213 |
-
all_patches = []
|
| 214 |
-
for img in images:
|
| 215 |
-
# Split each image further into two stacks of channels (each with 3 channels)
|
| 216 |
-
img_regular, img_fused = torch.split(img, [3, 3], dim=1)
|
| 217 |
-
|
| 218 |
-
# Get patches from both SigLIP and DINOv2 vision transformers
|
| 219 |
-
patches = self.featurizer(img_regular)
|
| 220 |
-
patches_fused = self.fused_featurizer(img_fused)
|
| 221 |
-
|
| 222 |
-
# Concatenate SigLIP and DINOv2 patches along the hidden dimension
|
| 223 |
-
combined_patches = torch.cat([patches, patches_fused], dim=2)
|
| 224 |
-
all_patches.append(combined_patches)
|
| 225 |
-
|
| 226 |
-
# Concatenate all patches along the patch dimension
|
| 227 |
-
return torch.cat(all_patches, dim=1)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
# === Prismatic Projector (nn.Module) Definitions ===
|
| 231 |
-
class PrismaticProjector(nn.Module):
|
| 232 |
-
def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
|
| 233 |
-
super().__init__()
|
| 234 |
-
self.use_fused_vision_backbone = use_fused_vision_backbone
|
| 235 |
-
self.vision_dim, self.llm_dim = vision_dim, llm_dim
|
| 236 |
-
|
| 237 |
-
# Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
|
| 238 |
-
if not self.use_fused_vision_backbone:
|
| 239 |
-
self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
|
| 240 |
-
self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
|
| 241 |
-
self.act_fn1 = nn.GELU()
|
| 242 |
-
else:
|
| 243 |
-
initial_projection_dim = 4 * vision_dim
|
| 244 |
-
self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
|
| 245 |
-
self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
|
| 246 |
-
self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
|
| 247 |
-
self.act_fn1 = nn.GELU()
|
| 248 |
-
self.act_fn2 = nn.GELU()
|
| 249 |
-
|
| 250 |
-
def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
|
| 251 |
-
if not self.use_fused_vision_backbone:
|
| 252 |
-
projected_features = self.fc1(img_patches)
|
| 253 |
-
projected_features = self.act_fn1(projected_features)
|
| 254 |
-
projected_features = self.fc2(projected_features)
|
| 255 |
-
else:
|
| 256 |
-
projected_features = self.fc1(img_patches)
|
| 257 |
-
projected_features = self.act_fn1(projected_features)
|
| 258 |
-
projected_features = self.fc2(projected_features)
|
| 259 |
-
projected_features = self.act_fn2(projected_features)
|
| 260 |
-
projected_features = self.fc3(projected_features)
|
| 261 |
-
|
| 262 |
-
return projected_features
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
# === Main HF Class Definitions ===
|
| 266 |
-
@dataclass
|
| 267 |
-
class PrismaticCausalLMOutputWithPast(ModelOutput):
|
| 268 |
-
"""Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
|
| 269 |
-
|
| 270 |
-
loss: Optional[torch.FloatTensor] = None
|
| 271 |
-
logits: torch.FloatTensor = None
|
| 272 |
-
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
| 273 |
-
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 274 |
-
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 275 |
-
|
| 276 |
-
# Additions for VLMs
|
| 277 |
-
projector_features: Optional[torch.FloatTensor] = None
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
class PrismaticPreTrainedModel(PreTrainedModel):
|
| 281 |
-
config_class: PretrainedConfig = PrismaticConfig
|
| 282 |
-
base_model_prefix: str = "model"
|
| 283 |
-
supports_gradient_checkpointing: bool = True
|
| 284 |
-
|
| 285 |
-
_no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
|
| 286 |
-
_skip_keys_device_placement: str = "past_key_values"
|
| 287 |
-
_supports_flash_attn_2: bool = True
|
| 288 |
-
|
| 289 |
-
def _init_weights(self, module: nn.Module) -> None:
|
| 290 |
-
# Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
|
| 291 |
-
# => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
|
| 292 |
-
# https://github.com/TRI-ML/prismatic-vlms
|
| 293 |
-
std = (
|
| 294 |
-
self.config.initializer_range
|
| 295 |
-
if hasattr(self.config, "initializer_range")
|
| 296 |
-
else self.config.text_config.initializer_range
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
if hasattr(module, "class_embedding"):
|
| 300 |
-
module.class_embedding.data.normal_(mean=0.0, std=std)
|
| 301 |
-
|
| 302 |
-
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 303 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 304 |
-
if module.bias is not None:
|
| 305 |
-
module.bias.data.zero_()
|
| 306 |
-
elif isinstance(module, nn.Embedding):
|
| 307 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 308 |
-
if module.padding_idx is not None:
|
| 309 |
-
module.weight.data[module.padding_idx].zero_()
|
| 310 |
-
|
| 311 |
-
@property
|
| 312 |
-
def _supports_sdpa(self) -> bool:
|
| 313 |
-
"""Check LLM supports SDPA Attention"""
|
| 314 |
-
return self.language_model._supports_sdpa
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
| 318 |
-
def __init__(self, config: PrismaticConfig) -> None:
|
| 319 |
-
super().__init__(config)
|
| 320 |
-
|
| 321 |
-
# [Validation] Lightweight Validate on `config` Fields + Dependency Versions
|
| 322 |
-
if config.use_fused_vision_backbone is None:
|
| 323 |
-
raise ValueError("Missing config field `use_fused_vision_backbone`")
|
| 324 |
-
|
| 325 |
-
if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
|
| 326 |
-
raise NotImplementedError(
|
| 327 |
-
"TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
|
| 328 |
-
"if you urgently need support for latest TIMM versions."
|
| 329 |
-
)
|
| 330 |
-
|
| 331 |
-
if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
|
| 332 |
-
logger.warning(
|
| 333 |
-
f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
|
| 334 |
-
f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
|
| 335 |
-
f"there might be inference-time regressions due to dependency changes. If in doubt, please"
|
| 336 |
-
f"use the above versions."
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
# Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
|
| 340 |
-
self.vision_backbone = PrismaticVisionBackbone(
|
| 341 |
-
config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
# Create Multimodal Projector
|
| 345 |
-
self.projector = PrismaticProjector(
|
| 346 |
-
config.use_fused_vision_backbone,
|
| 347 |
-
vision_dim=self.vision_backbone.embed_dim,
|
| 348 |
-
llm_dim=config.text_config.hidden_size,
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
# Instantiate LLM Backbone
|
| 352 |
-
self.language_model = AutoModelForCausalLM.from_config(
|
| 353 |
-
config.text_config, attn_implementation=config._attn_implementation
|
| 354 |
-
)
|
| 355 |
-
self.vocab_size = config.text_config.vocab_size
|
| 356 |
-
self.pad_token_id = config.pad_token_id
|
| 357 |
-
self.llm_dim = config.text_config.hidden_size
|
| 358 |
-
|
| 359 |
-
# HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
|
| 360 |
-
self.post_init()
|
| 361 |
-
|
| 362 |
-
# === `PreTrainedModel` Boilerplate ===
|
| 363 |
-
def get_input_embeddings(self) -> nn.Module:
|
| 364 |
-
return self.language_model.get_input_embeddings()
|
| 365 |
-
|
| 366 |
-
def set_input_embeddings(self, value: nn.Module) -> None:
|
| 367 |
-
self.language_model.set_input_embeddings(value)
|
| 368 |
-
|
| 369 |
-
def get_output_embeddings(self) -> nn.Module:
|
| 370 |
-
return self.language_model.get_output_embeddings()
|
| 371 |
-
|
| 372 |
-
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
|
| 373 |
-
self.language_model.set_output_embeddings(new_embeddings)
|
| 374 |
-
|
| 375 |
-
def get_decoder(self) -> nn.Module:
|
| 376 |
-
return self.language_model.get_decoder()
|
| 377 |
-
|
| 378 |
-
def set_decoder(self, decoder: nn.Module) -> None:
|
| 379 |
-
self.language_model.set_decoder(decoder)
|
| 380 |
-
|
| 381 |
-
def tie_weights(self) -> None:
|
| 382 |
-
self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
|
| 383 |
-
|
| 384 |
-
def resize_token_embeddings(
|
| 385 |
-
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
|
| 386 |
-
) -> nn.Embedding:
|
| 387 |
-
updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 388 |
-
|
| 389 |
-
# Update config/instance variables
|
| 390 |
-
self.config.text_config.vocab_size = updated_embeddings.num_embeddings
|
| 391 |
-
self.vocab_size = updated_embeddings.num_embeddings
|
| 392 |
-
|
| 393 |
-
return updated_embeddings
|
| 394 |
-
|
| 395 |
-
def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
|
| 396 |
-
"""
|
| 397 |
-
Replace embeddings in input_embeddings at positions where all_actions_mask is True
|
| 398 |
-
with embeddings from noisy_action_features, using vectorized operations.
|
| 399 |
-
|
| 400 |
-
Args:
|
| 401 |
-
input_embeddings: Tensor of shape (B, S, D)
|
| 402 |
-
all_actions_mask: Boolean tensor of shape (B, S)
|
| 403 |
-
noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
|
| 404 |
-
|
| 405 |
-
Returns:
|
| 406 |
-
Modified input_embeddings tensor
|
| 407 |
-
"""
|
| 408 |
-
# Clone input to avoid modifying the original tensor
|
| 409 |
-
new_input_embeddings = input_embeddings.clone()
|
| 410 |
-
|
| 411 |
-
# Create a tensor with the same shape of input_embeddings to hold the noisy action features
|
| 412 |
-
repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
|
| 413 |
-
|
| 414 |
-
# Create batch indices for splicing
|
| 415 |
-
batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
|
| 416 |
-
batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
|
| 417 |
-
|
| 418 |
-
# Get indices where mask is True for each sample
|
| 419 |
-
masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
|
| 420 |
-
|
| 421 |
-
# Move the noisy action features into their correct positions
|
| 422 |
-
repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
|
| 423 |
-
|
| 424 |
-
# Combine original input embeddings and noisy action embeddings using the mask
|
| 425 |
-
new_input_embeddings = torch.where(
|
| 426 |
-
all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
|
| 427 |
-
)
|
| 428 |
-
|
| 429 |
-
return new_input_embeddings
|
| 430 |
-
|
| 431 |
-
def _process_action_masks(self, labels):
|
| 432 |
-
"""Helper to get action masks from labels"""
|
| 433 |
-
current_action_mask = get_current_action_mask(labels)
|
| 434 |
-
next_actions_mask = get_next_actions_mask(labels)
|
| 435 |
-
all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
|
| 436 |
-
return all_actions_mask
|
| 437 |
-
|
| 438 |
-
def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
|
| 439 |
-
"""Process vision features with optional FiLM conditioning"""
|
| 440 |
-
if use_film:
|
| 441 |
-
# FiLM: Infuse language inputs into visual features
|
| 442 |
-
patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
|
| 443 |
-
else:
|
| 444 |
-
patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
|
| 445 |
-
|
| 446 |
-
# Project patch embeddings into language embedding space
|
| 447 |
-
return self.projector(patch_features)
|
| 448 |
-
|
| 449 |
-
def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
|
| 450 |
-
"""Process proprioceptive features and append to vision features"""
|
| 451 |
-
if proprio_projector is not None and proprio is not None:
|
| 452 |
-
# projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
|
| 453 |
-
# proprio: (bsz, proprio_dim) or (propro_dim,)
|
| 454 |
-
proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
|
| 455 |
-
proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
|
| 456 |
-
proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
|
| 457 |
-
# For simplicity, just append proprio token to the end of projected vision patch tokens
|
| 458 |
-
return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
|
| 459 |
-
return projected_patch_embeddings
|
| 460 |
-
|
| 461 |
-
def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
|
| 462 |
-
"""Build multimodal embeddings and attention mask"""
|
| 463 |
-
# Update attention mask
|
| 464 |
-
projected_patch_attention_mask = None
|
| 465 |
-
if attention_mask is not None:
|
| 466 |
-
projected_patch_attention_mask = torch.full(
|
| 467 |
-
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
|
| 468 |
-
fill_value=True,
|
| 469 |
-
dtype=attention_mask.dtype,
|
| 470 |
-
device=attention_mask.device,
|
| 471 |
-
)
|
| 472 |
-
|
| 473 |
-
# Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
|
| 474 |
-
multimodal_embeddings = torch.cat(
|
| 475 |
-
[input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
|
| 476 |
-
)
|
| 477 |
-
|
| 478 |
-
multimodal_attention_mask = None
|
| 479 |
-
if attention_mask is not None:
|
| 480 |
-
multimodal_attention_mask = torch.cat(
|
| 481 |
-
[attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
|
| 482 |
-
)
|
| 483 |
-
|
| 484 |
-
return multimodal_embeddings, multimodal_attention_mask
|
| 485 |
-
|
| 486 |
-
def _build_multimodal_labels(self, labels, projected_patch_embeddings):
|
| 487 |
-
"""Build multimodal labels with IGNORE_INDEX for patch embeddings"""
|
| 488 |
-
if labels is not None:
|
| 489 |
-
projected_patch_labels = torch.full(
|
| 490 |
-
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
|
| 491 |
-
fill_value=IGNORE_INDEX,
|
| 492 |
-
dtype=labels.dtype,
|
| 493 |
-
device=labels.device,
|
| 494 |
-
)
|
| 495 |
-
return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
|
| 496 |
-
return None
|
| 497 |
-
|
| 498 |
-
# === Core Prismatic VLM `forward()` Logic ===
|
| 499 |
-
def forward(
|
| 500 |
-
self,
|
| 501 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 502 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 503 |
-
pixel_values: Optional[torch.FloatTensor] = None,
|
| 504 |
-
labels: Optional[torch.LongTensor] = None,
|
| 505 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 506 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 507 |
-
use_cache: Optional[bool] = None,
|
| 508 |
-
output_attentions: Optional[bool] = None,
|
| 509 |
-
output_hidden_states: Optional[bool] = None,
|
| 510 |
-
output_projector_features: Optional[bool] = None,
|
| 511 |
-
return_dict: Optional[bool] = None,
|
| 512 |
-
proprio=None,
|
| 513 |
-
proprio_projector=None,
|
| 514 |
-
noisy_actions=None,
|
| 515 |
-
noisy_action_projector=None,
|
| 516 |
-
diffusion_timestep_embeddings=None,
|
| 517 |
-
use_film: bool = False,
|
| 518 |
-
) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
|
| 519 |
-
"""Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
|
| 520 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 521 |
-
output_hidden_states = (
|
| 522 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 523 |
-
)
|
| 524 |
-
output_projector_features = output_projector_features if output_projector_features is not None else False
|
| 525 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 526 |
-
|
| 527 |
-
# Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
|
| 528 |
-
use_cache = use_cache and not self.training
|
| 529 |
-
|
| 530 |
-
# Instantiate Placeholder for Projector Features
|
| 531 |
-
projected_patch_embeddings = None
|
| 532 |
-
|
| 533 |
-
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
|
| 534 |
-
if input_ids.shape[1] == 1:
|
| 535 |
-
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
|
| 536 |
-
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
|
| 537 |
-
assert labels is None, "Unexpected key `labels` provided during cached generation!"
|
| 538 |
-
|
| 539 |
-
language_model_output = self.language_model(
|
| 540 |
-
input_ids=input_ids,
|
| 541 |
-
attention_mask=None,
|
| 542 |
-
position_ids=None,
|
| 543 |
-
past_key_values=past_key_values,
|
| 544 |
-
inputs_embeds=None,
|
| 545 |
-
labels=None,
|
| 546 |
-
use_cache=use_cache,
|
| 547 |
-
output_attentions=output_attentions,
|
| 548 |
-
output_hidden_states=output_hidden_states,
|
| 549 |
-
return_dict=return_dict,
|
| 550 |
-
)
|
| 551 |
-
|
| 552 |
-
# === Handle Unimodal Forward ===
|
| 553 |
-
elif pixel_values is None:
|
| 554 |
-
assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
|
| 555 |
-
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
|
| 556 |
-
|
| 557 |
-
language_model_output = self.language_model(
|
| 558 |
-
input_ids=input_ids,
|
| 559 |
-
attention_mask=attention_mask,
|
| 560 |
-
position_ids=None,
|
| 561 |
-
past_key_values=None,
|
| 562 |
-
inputs_embeds=None,
|
| 563 |
-
labels=labels,
|
| 564 |
-
use_cache=use_cache,
|
| 565 |
-
output_attentions=output_attentions,
|
| 566 |
-
output_hidden_states=output_hidden_states,
|
| 567 |
-
return_dict=return_dict,
|
| 568 |
-
)
|
| 569 |
-
|
| 570 |
-
# === Handle Multimodal Forward ===
|
| 571 |
-
elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
|
| 572 |
-
assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
|
| 573 |
-
|
| 574 |
-
# Get input embeddings (from language model embeddings)
|
| 575 |
-
input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
|
| 576 |
-
|
| 577 |
-
# Extract action masks
|
| 578 |
-
all_actions_mask = self._process_action_masks(labels)
|
| 579 |
-
|
| 580 |
-
# Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
|
| 581 |
-
language_embeddings = input_embeddings[~all_actions_mask].reshape(
|
| 582 |
-
input_embeddings.shape[0], -1, input_embeddings.shape[2]
|
| 583 |
-
) # (B, lang_seq_len, llm_dim)
|
| 584 |
-
|
| 585 |
-
# Get visual features
|
| 586 |
-
projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
|
| 587 |
-
|
| 588 |
-
# Add proprioceptive state if provided
|
| 589 |
-
projected_patch_embeddings = self._process_proprio_features(
|
| 590 |
-
projected_patch_embeddings, proprio, proprio_projector
|
| 591 |
-
)
|
| 592 |
-
|
| 593 |
-
# [Diffusion] Add diffusion timestep embedding if provided
|
| 594 |
-
if diffusion_timestep_embeddings is not None:
|
| 595 |
-
# For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
|
| 596 |
-
projected_patch_embeddings = torch.cat(
|
| 597 |
-
(projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
|
| 598 |
-
)
|
| 599 |
-
|
| 600 |
-
# Process action embeddings
|
| 601 |
-
if noisy_actions is not None:
|
| 602 |
-
# Get mask corresponding to all action tokens
|
| 603 |
-
all_actions_mask = self._process_action_masks(labels)
|
| 604 |
-
|
| 605 |
-
# Reshape noisy actions into individual action tokens
|
| 606 |
-
# noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
|
| 607 |
-
B = noisy_actions.shape[0]
|
| 608 |
-
noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
|
| 609 |
-
|
| 610 |
-
# Project noisy action tokens into language model embedding space
|
| 611 |
-
noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
|
| 612 |
-
|
| 613 |
-
# Replace embeddings of the action tokens with noisy action embeddings
|
| 614 |
-
input_embeddings = self._replace_input_embeddings(
|
| 615 |
-
input_embeddings, all_actions_mask, noisy_action_features
|
| 616 |
-
)
|
| 617 |
-
else:
|
| 618 |
-
# Replace the embeddings of the action tokens with zeros
|
| 619 |
-
# (Later on, the positional embeddings will be added to them)
|
| 620 |
-
all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
|
| 621 |
-
input_embeddings = input_embeddings * ~all_actions_mask
|
| 622 |
-
|
| 623 |
-
# Build multimodal embeddings & attention mask
|
| 624 |
-
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
| 625 |
-
input_embeddings, projected_patch_embeddings, attention_mask
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
# Build labels for multimodal sequence if needed
|
| 629 |
-
multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
|
| 630 |
-
|
| 631 |
-
# Dispatch to language model
|
| 632 |
-
language_model_output = self.language_model(
|
| 633 |
-
input_ids=None,
|
| 634 |
-
attention_mask=multimodal_attention_mask,
|
| 635 |
-
position_ids=None,
|
| 636 |
-
past_key_values=None,
|
| 637 |
-
inputs_embeds=multimodal_embeddings,
|
| 638 |
-
labels=multimodal_labels,
|
| 639 |
-
use_cache=use_cache,
|
| 640 |
-
output_attentions=output_attentions,
|
| 641 |
-
output_hidden_states=output_hidden_states,
|
| 642 |
-
return_dict=return_dict,
|
| 643 |
-
)
|
| 644 |
-
|
| 645 |
-
# === Otherwise =>> Assume Invalid! ===
|
| 646 |
-
elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
|
| 647 |
-
raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
|
| 648 |
-
|
| 649 |
-
else:
|
| 650 |
-
raise ValueError(
|
| 651 |
-
"Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
|
| 652 |
-
f"=> `input_ids` = {input_ids is not None}\n"
|
| 653 |
-
f"=> `attention_mask` = {attention_mask is not None}\n"
|
| 654 |
-
f"=> `pixel_values` = {pixel_values is not None}\n"
|
| 655 |
-
f"=> `labels` = {labels is not None}\n"
|
| 656 |
-
f"=> `input_embeds` = {inputs_embeds is not None}\n"
|
| 657 |
-
f"=> `past_key_values` = {past_key_values is not None}\n"
|
| 658 |
-
f"=> `use_cache` = {use_cache}"
|
| 659 |
-
)
|
| 660 |
-
|
| 661 |
-
# Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
|
| 662 |
-
if not return_dict:
|
| 663 |
-
if output_projector_features and (projected_patch_embeddings is not None):
|
| 664 |
-
return *language_model_output, projected_patch_embeddings
|
| 665 |
-
|
| 666 |
-
return language_model_output
|
| 667 |
-
|
| 668 |
-
return PrismaticCausalLMOutputWithPast(
|
| 669 |
-
loss=language_model_output.loss,
|
| 670 |
-
logits=language_model_output.logits,
|
| 671 |
-
past_key_values=language_model_output.past_key_values,
|
| 672 |
-
hidden_states=language_model_output.hidden_states,
|
| 673 |
-
attentions=language_model_output.attentions,
|
| 674 |
-
projector_features=projected_patch_embeddings,
|
| 675 |
-
)
|
| 676 |
-
|
| 677 |
-
# === GenerationMixin Methods ===
|
| 678 |
-
def prepare_inputs_for_generation(
|
| 679 |
-
self,
|
| 680 |
-
input_ids: Optional[torch.Tensor] = None,
|
| 681 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 682 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 683 |
-
pixel_values: Optional[torch.FloatTensor] = None,
|
| 684 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 685 |
-
**kwargs: str,
|
| 686 |
-
) -> Dict[str, torch.Tensor]:
|
| 687 |
-
"""Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
|
| 688 |
-
if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
|
| 689 |
-
(inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
|
| 690 |
-
):
|
| 691 |
-
raise ValueError("Generation with batch size > 1 is not currently supported!")
|
| 692 |
-
|
| 693 |
-
# Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
|
| 694 |
-
if past_key_values is not None:
|
| 695 |
-
input_ids = input_ids[:, -1:]
|
| 696 |
-
|
| 697 |
-
# If `input_embeds` are passed, we only want to use them in the 1st generation step
|
| 698 |
-
if inputs_embeds is not None and past_key_values is None:
|
| 699 |
-
model_inputs = {"input_embeds": inputs_embeds}
|
| 700 |
-
else:
|
| 701 |
-
model_inputs = {"input_ids": input_ids}
|
| 702 |
-
|
| 703 |
-
# Make sure `pixel_values` are preserved in `model_inputs`
|
| 704 |
-
model_inputs.update(
|
| 705 |
-
{
|
| 706 |
-
"attention_mask": attention_mask,
|
| 707 |
-
"pixel_values": pixel_values,
|
| 708 |
-
"past_key_values": past_key_values,
|
| 709 |
-
"use_cache": kwargs.get("use_cache"),
|
| 710 |
-
}
|
| 711 |
-
)
|
| 712 |
-
|
| 713 |
-
return model_inputs
|
| 714 |
-
|
| 715 |
-
# Defer to Language Model (all handle this differently, with different return types)
|
| 716 |
-
def _reorder_cache(self, *args, **kwargs) -> Any:
|
| 717 |
-
return self.language_model._reorder_cache(*args, **kwargs)
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
| 721 |
-
config_class: PretrainedConfig = OpenVLAConfig
|
| 722 |
-
|
| 723 |
-
def __init__(self, config: OpenVLAConfig) -> None:
|
| 724 |
-
super().__init__(config)
|
| 725 |
-
self.norm_stats = config.norm_stats
|
| 726 |
-
|
| 727 |
-
# Compute action bins
|
| 728 |
-
self.bins = np.linspace(-1, 1, config.n_action_bins)
|
| 729 |
-
self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
|
| 730 |
-
|
| 731 |
-
# Compute vocab size for de-tokenization -- revert added "multiple of"
|
| 732 |
-
self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
|
| 733 |
-
|
| 734 |
-
def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
|
| 735 |
-
"""Prepares input for action prediction by adding necessary tokens"""
|
| 736 |
-
# Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
|
| 737 |
-
placeholder_action_token_ids = (
|
| 738 |
-
torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
|
| 739 |
-
)
|
| 740 |
-
input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
|
| 741 |
-
|
| 742 |
-
# Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
|
| 743 |
-
stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
|
| 744 |
-
input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
|
| 745 |
-
|
| 746 |
-
# Extend the attention mask to fit the new shape of input
|
| 747 |
-
# Note: Only batch size == 1 supported right now
|
| 748 |
-
mask_extension = (
|
| 749 |
-
torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
|
| 750 |
-
.to(attention_mask.device)
|
| 751 |
-
.to(attention_mask.dtype)
|
| 752 |
-
)
|
| 753 |
-
attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
|
| 754 |
-
|
| 755 |
-
return input_ids, attention_mask
|
| 756 |
-
|
| 757 |
-
def _prepare_labels_for_action_prediction(self, labels, input_ids):
|
| 758 |
-
"""Creates labels tensor for action prediction if not provided"""
|
| 759 |
-
# Extend labels tensor with fake action labels
|
| 760 |
-
ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
|
| 761 |
-
labels_extension = (
|
| 762 |
-
torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
|
| 763 |
-
* ARBITRARY_ACTION_TOKEN_IDX
|
| 764 |
-
)
|
| 765 |
-
labels = torch.cat([labels, labels_extension], dim=-1)
|
| 766 |
-
|
| 767 |
-
# Replace last label token with stop token
|
| 768 |
-
labels[:, -1] = STOP_INDEX
|
| 769 |
-
|
| 770 |
-
return labels
|
| 771 |
-
|
| 772 |
-
def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
|
| 773 |
-
"""Unnormalize actions using dataset statistics"""
|
| 774 |
-
action_norm_stats = self.get_action_stats(unnorm_key)
|
| 775 |
-
|
| 776 |
-
if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
|
| 777 |
-
mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
|
| 778 |
-
action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
|
| 779 |
-
elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
|
| 780 |
-
mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
|
| 781 |
-
action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
|
| 782 |
-
else:
|
| 783 |
-
raise ValueError("Unsupported action/proprio normalization type detected!")
|
| 784 |
-
|
| 785 |
-
actions = np.where(
|
| 786 |
-
mask,
|
| 787 |
-
0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
|
| 788 |
-
normalized_actions,
|
| 789 |
-
)
|
| 790 |
-
|
| 791 |
-
return actions
|
| 792 |
-
|
| 793 |
-
def _run_diffusion_prediction(
|
| 794 |
-
self,
|
| 795 |
-
input_embeddings,
|
| 796 |
-
all_actions_mask,
|
| 797 |
-
noise,
|
| 798 |
-
action_head,
|
| 799 |
-
projected_patch_embeddings,
|
| 800 |
-
labels,
|
| 801 |
-
attention_mask,
|
| 802 |
-
NUM_PATCHES,
|
| 803 |
-
NUM_PROMPT_TOKENS,
|
| 804 |
-
noisy_action_projector,
|
| 805 |
-
):
|
| 806 |
-
"""Run diffusion-based action prediction"""
|
| 807 |
-
# Clone embedding for reuse in each timestep
|
| 808 |
-
orig_projected_patch_embeddings = projected_patch_embeddings.clone()
|
| 809 |
-
curr_noisy_actions = noise
|
| 810 |
-
|
| 811 |
-
# Reverse diffusion: Iteratively denoise to generate action prediction
|
| 812 |
-
for t in action_head.noise_scheduler.timesteps:
|
| 813 |
-
# Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
|
| 814 |
-
# embedding, and diffusion timestep embedding)
|
| 815 |
-
timesteps = torch.Tensor([t]).to(labels.device)
|
| 816 |
-
diffusion_timestep_embeddings = (
|
| 817 |
-
action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
|
| 818 |
-
) # (B, llm_dim)
|
| 819 |
-
diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
|
| 820 |
-
|
| 821 |
-
# [Diffusion] Replace the embeddings of the action tokens with noisy actions
|
| 822 |
-
# (Later on, the positional embeddings will be added to them)
|
| 823 |
-
|
| 824 |
-
# For simplicity, append diffusion timestep embedding to the end of projected vision tokens
|
| 825 |
-
projected_patch_embeddings = torch.cat(
|
| 826 |
-
(orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
|
| 827 |
-
)
|
| 828 |
-
|
| 829 |
-
# Reshape and project noisy actions into language embedding space
|
| 830 |
-
B = curr_noisy_actions.shape[0]
|
| 831 |
-
orig_curr_noisy_actions_shape = curr_noisy_actions.shape
|
| 832 |
-
curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
|
| 833 |
-
noisy_action_features = noisy_action_projector(curr_noisy_actions)
|
| 834 |
-
curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
|
| 835 |
-
|
| 836 |
-
# Replace action token embeddings with noisy action embeddings
|
| 837 |
-
input_embeddings = self._replace_input_embeddings(
|
| 838 |
-
input_embeddings.clone(), all_actions_mask, noisy_action_features
|
| 839 |
-
)
|
| 840 |
-
|
| 841 |
-
# Build multimodal embeddings and attention mask
|
| 842 |
-
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
| 843 |
-
input_embeddings, projected_patch_embeddings, attention_mask
|
| 844 |
-
)
|
| 845 |
-
|
| 846 |
-
# Forward pass through language model
|
| 847 |
-
language_model_output = self.language_model(
|
| 848 |
-
input_ids=None,
|
| 849 |
-
attention_mask=multimodal_attention_mask,
|
| 850 |
-
position_ids=None,
|
| 851 |
-
past_key_values=None,
|
| 852 |
-
inputs_embeds=multimodal_embeddings,
|
| 853 |
-
labels=None,
|
| 854 |
-
use_cache=None,
|
| 855 |
-
output_attentions=False,
|
| 856 |
-
output_hidden_states=True,
|
| 857 |
-
return_dict=True,
|
| 858 |
-
)
|
| 859 |
-
|
| 860 |
-
# Extract hidden states for action portion of response
|
| 861 |
-
last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
|
| 862 |
-
actions_hidden_states = last_hidden_states[
|
| 863 |
-
:,
|
| 864 |
-
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
| 865 |
-
:,
|
| 866 |
-
] # (B, act_chunk_len, D)
|
| 867 |
-
|
| 868 |
-
# Predict noise and update noisy actions: x_t -> x_{t-1}
|
| 869 |
-
noise_pred = action_head.predict_noise(actions_hidden_states)
|
| 870 |
-
curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
|
| 871 |
-
|
| 872 |
-
curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
| 873 |
-
|
| 874 |
-
# Return final actions
|
| 875 |
-
return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
|
| 876 |
-
|
| 877 |
-
def _regression_or_discrete_prediction(
|
| 878 |
-
self,
|
| 879 |
-
input_embeddings,
|
| 880 |
-
all_actions_mask,
|
| 881 |
-
projected_patch_embeddings,
|
| 882 |
-
attention_mask,
|
| 883 |
-
labels,
|
| 884 |
-
NUM_PATCHES,
|
| 885 |
-
NUM_PROMPT_TOKENS,
|
| 886 |
-
action_head=None,
|
| 887 |
-
):
|
| 888 |
-
"""Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
|
| 889 |
-
# Zero out action token embeddings
|
| 890 |
-
all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
|
| 891 |
-
input_embeddings = input_embeddings * ~all_actions_mask
|
| 892 |
-
|
| 893 |
-
# Build multimodal embeddings and attention mask
|
| 894 |
-
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
| 895 |
-
input_embeddings, projected_patch_embeddings, attention_mask
|
| 896 |
-
)
|
| 897 |
-
|
| 898 |
-
# Forward pass through language model
|
| 899 |
-
language_model_output = self.language_model(
|
| 900 |
-
input_ids=None,
|
| 901 |
-
attention_mask=multimodal_attention_mask,
|
| 902 |
-
position_ids=None,
|
| 903 |
-
past_key_values=None,
|
| 904 |
-
inputs_embeds=multimodal_embeddings,
|
| 905 |
-
labels=None,
|
| 906 |
-
use_cache=None,
|
| 907 |
-
output_attentions=False,
|
| 908 |
-
output_hidden_states=True,
|
| 909 |
-
return_dict=True,
|
| 910 |
-
)
|
| 911 |
-
|
| 912 |
-
# Extract hidden states for action tokens
|
| 913 |
-
last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
|
| 914 |
-
actions_hidden_states = last_hidden_states[
|
| 915 |
-
:,
|
| 916 |
-
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
| 917 |
-
:,
|
| 918 |
-
] # (B, act_chunk_len, D)
|
| 919 |
-
|
| 920 |
-
# Handle different prediction methods
|
| 921 |
-
if action_head is not None:
|
| 922 |
-
# L1 regression prediction
|
| 923 |
-
normalized_actions = action_head.predict_action(actions_hidden_states)
|
| 924 |
-
normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
| 925 |
-
normalized_actions = normalized_actions.float().cpu().detach().numpy()
|
| 926 |
-
else:
|
| 927 |
-
# Discrete token-based prediction
|
| 928 |
-
predicted_action_token_ids = (
|
| 929 |
-
language_model_output.logits[
|
| 930 |
-
:,
|
| 931 |
-
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
| 932 |
-
]
|
| 933 |
-
.argmax(dim=2)
|
| 934 |
-
.cpu()
|
| 935 |
-
.numpy()
|
| 936 |
-
)
|
| 937 |
-
discretized_actions = self.vocab_size - predicted_action_token_ids
|
| 938 |
-
discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
|
| 939 |
-
normalized_actions = self.bin_centers[discretized_actions]
|
| 940 |
-
normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
| 941 |
-
|
| 942 |
-
return normalized_actions, actions_hidden_states
|
| 943 |
-
|
| 944 |
-
def predict_action(
|
| 945 |
-
self,
|
| 946 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 947 |
-
unnorm_key: Optional[str] = None,
|
| 948 |
-
proprio=None,
|
| 949 |
-
proprio_projector=None,
|
| 950 |
-
action_head=None,
|
| 951 |
-
noisy_action_projector=None,
|
| 952 |
-
use_film: bool = False,
|
| 953 |
-
**kwargs: str,
|
| 954 |
-
) -> np.ndarray:
|
| 955 |
-
"""Predict actions from input sequence, with options for different prediction methods.
|
| 956 |
-
|
| 957 |
-
Args:
|
| 958 |
-
input_ids: Input token ids
|
| 959 |
-
unnorm_key: Key for unnormalization statistics
|
| 960 |
-
proprio: Proprioceptive features
|
| 961 |
-
proprio_projector: Projector for proprioceptive features
|
| 962 |
-
action_head: Optional head for L1 regression or diffusion-based prediction
|
| 963 |
-
noisy_action_projector: Projector for noisy actions in diffusion-based prediction
|
| 964 |
-
use_film: Whether to use FiLM conditioning
|
| 965 |
-
**kwargs: Additional arguments including pixel_values and attention_mask
|
| 966 |
-
|
| 967 |
-
Returns:
|
| 968 |
-
Tuple of (unnormalized_actions, action_hidden_states)
|
| 969 |
-
"""
|
| 970 |
-
# If the special empty token ('') does not already appear after the colon (':') token in the prompt
|
| 971 |
-
# (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
|
| 972 |
-
if not torch.all(input_ids[:, -1] == 29871):
|
| 973 |
-
input_ids = torch.cat(
|
| 974 |
-
(input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
|
| 975 |
-
)
|
| 976 |
-
|
| 977 |
-
pixel_values = kwargs["pixel_values"]
|
| 978 |
-
attention_mask = kwargs["attention_mask"]
|
| 979 |
-
|
| 980 |
-
# Create fake labels tensor (needed for action mask)
|
| 981 |
-
labels = input_ids.clone()
|
| 982 |
-
labels[:] = IGNORE_INDEX
|
| 983 |
-
|
| 984 |
-
# Get number of tokens in prompt (excluding the start token)
|
| 985 |
-
NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
|
| 986 |
-
|
| 987 |
-
# Prepare inputs by adding necessary tokens
|
| 988 |
-
input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
|
| 989 |
-
|
| 990 |
-
# Update labels tensor for action mask computation later
|
| 991 |
-
labels = self._prepare_labels_for_action_prediction(labels, input_ids)
|
| 992 |
-
|
| 993 |
-
# Get input embeddings and action masks
|
| 994 |
-
input_embeddings = self.get_input_embeddings()(input_ids)
|
| 995 |
-
all_actions_mask = self._process_action_masks(labels)
|
| 996 |
-
|
| 997 |
-
# Extract language embeddings
|
| 998 |
-
language_embeddings = input_embeddings[~all_actions_mask].reshape(
|
| 999 |
-
input_embeddings.shape[0], -1, input_embeddings.shape[2]
|
| 1000 |
-
)
|
| 1001 |
-
|
| 1002 |
-
# Process vision features
|
| 1003 |
-
projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
|
| 1004 |
-
|
| 1005 |
-
# Add proprioceptive features if provided
|
| 1006 |
-
use_proprio = proprio_projector is not None and proprio is not None
|
| 1007 |
-
if use_proprio:
|
| 1008 |
-
proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
|
| 1009 |
-
projected_patch_embeddings = self._process_proprio_features(
|
| 1010 |
-
projected_patch_embeddings, proprio, proprio_projector
|
| 1011 |
-
)
|
| 1012 |
-
|
| 1013 |
-
# Use diffusion if provided, otherwise use regression or discrete prediction
|
| 1014 |
-
use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
|
| 1015 |
-
|
| 1016 |
-
# Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
|
| 1017 |
-
NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
|
| 1018 |
-
if use_proprio:
|
| 1019 |
-
NUM_PATCHES += 1
|
| 1020 |
-
if use_diffusion:
|
| 1021 |
-
NUM_PATCHES += 1
|
| 1022 |
-
|
| 1023 |
-
if use_diffusion:
|
| 1024 |
-
# Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
|
| 1025 |
-
noise = torch.randn(
|
| 1026 |
-
size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
|
| 1027 |
-
)
|
| 1028 |
-
|
| 1029 |
-
# Run diffusion-based prediction
|
| 1030 |
-
normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
|
| 1031 |
-
input_embeddings,
|
| 1032 |
-
all_actions_mask,
|
| 1033 |
-
noise,
|
| 1034 |
-
action_head,
|
| 1035 |
-
projected_patch_embeddings,
|
| 1036 |
-
labels,
|
| 1037 |
-
attention_mask,
|
| 1038 |
-
NUM_PATCHES,
|
| 1039 |
-
NUM_PROMPT_TOKENS,
|
| 1040 |
-
noisy_action_projector,
|
| 1041 |
-
)
|
| 1042 |
-
else:
|
| 1043 |
-
# Run regression or discrete token-based prediction
|
| 1044 |
-
normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
|
| 1045 |
-
input_embeddings,
|
| 1046 |
-
all_actions_mask,
|
| 1047 |
-
projected_patch_embeddings,
|
| 1048 |
-
attention_mask,
|
| 1049 |
-
labels,
|
| 1050 |
-
NUM_PATCHES,
|
| 1051 |
-
NUM_PROMPT_TOKENS,
|
| 1052 |
-
action_head,
|
| 1053 |
-
)
|
| 1054 |
-
|
| 1055 |
-
# Unnormalize predicted actions
|
| 1056 |
-
actions = self._unnormalize_actions(normalized_actions, unnorm_key)
|
| 1057 |
-
|
| 1058 |
-
return actions, actions_hidden_states
|
| 1059 |
-
|
| 1060 |
-
@staticmethod
|
| 1061 |
-
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
|
| 1062 |
-
"""Validate and resolve the unnormalization key for action statistics"""
|
| 1063 |
-
if unnorm_key is None:
|
| 1064 |
-
assert len(norm_stats) == 1, (
|
| 1065 |
-
f"Your model was trained on more than one dataset, "
|
| 1066 |
-
f"please pass a `unnorm_key` from the following options to choose the statistics "
|
| 1067 |
-
f"used for un-normalizing actions: {norm_stats.keys()}"
|
| 1068 |
-
)
|
| 1069 |
-
unnorm_key = next(iter(norm_stats.keys()))
|
| 1070 |
-
|
| 1071 |
-
assert unnorm_key in norm_stats, (
|
| 1072 |
-
f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
|
| 1073 |
-
f"please choose from: {norm_stats.keys()}"
|
| 1074 |
-
)
|
| 1075 |
-
return unnorm_key
|
| 1076 |
-
|
| 1077 |
-
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
|
| 1078 |
-
"""Get the dimensionality of the policy's action space."""
|
| 1079 |
-
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
| 1080 |
-
return len(self.norm_stats[unnorm_key]["action"]["min"])
|
| 1081 |
-
|
| 1082 |
-
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
|
| 1083 |
-
"""Get all the logged statistics for the given dataset."""
|
| 1084 |
-
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
| 1085 |
-
return self.norm_stats[unnorm_key]["action"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/extern/hf/processing_prismatic.py
DELETED
|
@@ -1,252 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
processing_prismatic.py
|
| 3 |
-
|
| 4 |
-
HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
|
| 5 |
-
specifies `siglip-224px+7b`.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from typing import Any, ClassVar, List, Optional, Tuple, Union
|
| 9 |
-
|
| 10 |
-
import timm.data
|
| 11 |
-
import torch
|
| 12 |
-
import torchvision.transforms.functional as TVF
|
| 13 |
-
from PIL import Image
|
| 14 |
-
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
|
| 15 |
-
from transformers import PreTrainedTokenizerBase
|
| 16 |
-
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
|
| 17 |
-
from transformers.processing_utils import ProcessorMixin
|
| 18 |
-
from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
| 19 |
-
from transformers.utils import TensorType
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# === Image Processing ===
|
| 23 |
-
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
|
| 24 |
-
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
|
| 25 |
-
(w, h), max_wh = image.size, max(image.size)
|
| 26 |
-
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
|
| 27 |
-
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
|
| 28 |
-
|
| 29 |
-
return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class PrismaticImageProcessor(ImageProcessingMixin):
|
| 33 |
-
model_input_names: ClassVar[List[str]] = ["pixel_values"]
|
| 34 |
-
|
| 35 |
-
def __init__(
|
| 36 |
-
self,
|
| 37 |
-
use_fused_vision_backbone: bool = False,
|
| 38 |
-
image_resize_strategy: str = "letterbox",
|
| 39 |
-
input_sizes: Optional[List[Tuple[int, int, int]]] = None,
|
| 40 |
-
interpolations: Optional[List[str]] = None,
|
| 41 |
-
means: Optional[List[Tuple[float, float, float]]] = None,
|
| 42 |
-
stds: Optional[List[Tuple[float, float, float]]] = None,
|
| 43 |
-
**kwargs: str,
|
| 44 |
-
) -> None:
|
| 45 |
-
"""
|
| 46 |
-
Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
|
| 47 |
-
created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
|
| 48 |
-
@param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
|
| 49 |
-
@param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
|
| 50 |
-
@param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
|
| 51 |
-
@param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
|
| 52 |
-
@param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
|
| 53 |
-
@param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
|
| 54 |
-
"""
|
| 55 |
-
self.use_fused_vision_backbone = use_fused_vision_backbone
|
| 56 |
-
self.image_resize_strategy = image_resize_strategy
|
| 57 |
-
|
| 58 |
-
# Handle `None` default values
|
| 59 |
-
input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
|
| 60 |
-
means = [(0.5, 0.5, 0.5)] if means is None else means
|
| 61 |
-
stds = [(0.5, 0.5, 0.5)] if stds is None else stds
|
| 62 |
-
|
| 63 |
-
# TIMM `data_cfg` Parameters
|
| 64 |
-
self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
|
| 65 |
-
|
| 66 |
-
# Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
|
| 67 |
-
self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
|
| 68 |
-
self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
|
| 69 |
-
|
| 70 |
-
for idx in range(len(input_sizes)):
|
| 71 |
-
transform = timm.data.create_transform(
|
| 72 |
-
input_size=self.input_sizes[idx],
|
| 73 |
-
interpolation=self.interpolations[idx],
|
| 74 |
-
mean=self.means[idx],
|
| 75 |
-
std=self.stds[idx],
|
| 76 |
-
crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
|
| 77 |
-
crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
|
| 78 |
-
is_training=False, # No image augmentations when loading the transform!
|
| 79 |
-
)
|
| 80 |
-
|
| 81 |
-
# [Validation] Ensure appropriate transform structure, expected sizes
|
| 82 |
-
if not (
|
| 83 |
-
isinstance(transform, Compose)
|
| 84 |
-
and (len(transform.transforms) == 4)
|
| 85 |
-
and isinstance(transform.transforms[0], Resize)
|
| 86 |
-
and isinstance(transform.transforms[1], CenterCrop)
|
| 87 |
-
and isinstance(transform.transforms[2], ToTensor)
|
| 88 |
-
and isinstance(transform.transforms[3], Normalize)
|
| 89 |
-
and (transform.transforms[0].size == self.input_sizes[idx][-1])
|
| 90 |
-
and (transform.transforms[1].size == self.input_sizes[idx][-2:])
|
| 91 |
-
):
|
| 92 |
-
raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
|
| 93 |
-
|
| 94 |
-
# HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
|
| 95 |
-
# => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
|
| 96 |
-
resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
|
| 97 |
-
self.tvf_resize_params.append(
|
| 98 |
-
{
|
| 99 |
-
"size": resize_t.size,
|
| 100 |
-
"interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
|
| 101 |
-
"max_size": None,
|
| 102 |
-
"antialias": True,
|
| 103 |
-
}
|
| 104 |
-
)
|
| 105 |
-
self.tvf_crop_params.append({"output_size": crop_t.size})
|
| 106 |
-
self.tvf_normalize_params.append(
|
| 107 |
-
{
|
| 108 |
-
"mean": norm_t.mean.float().numpy().tolist(),
|
| 109 |
-
"std": norm_t.std.float().numpy().tolist(),
|
| 110 |
-
"inplace": False,
|
| 111 |
-
}
|
| 112 |
-
)
|
| 113 |
-
self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
|
| 114 |
-
|
| 115 |
-
# Handle Prismatic `image_resize_strategy`
|
| 116 |
-
if self.image_resize_strategy == "resize-naive":
|
| 117 |
-
self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
|
| 118 |
-
elif self.image_resize_strategy == "letterbox":
|
| 119 |
-
self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
|
| 120 |
-
elif self.image_resize_strategy == "resize-crop":
|
| 121 |
-
pass
|
| 122 |
-
else:
|
| 123 |
-
raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
|
| 124 |
-
|
| 125 |
-
# Dispatch **kwargs to super()
|
| 126 |
-
super().__init__(**kwargs)
|
| 127 |
-
|
| 128 |
-
def apply_transform(self, img: Image.Image) -> torch.Tensor:
|
| 129 |
-
"""Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
|
| 130 |
-
if self.tvf_do_letterbox:
|
| 131 |
-
img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
|
| 132 |
-
|
| 133 |
-
# [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
|
| 134 |
-
imgs_t = []
|
| 135 |
-
for idx in range(len(self.input_sizes)):
|
| 136 |
-
img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
|
| 137 |
-
img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
|
| 138 |
-
img_idx_t = TVF.to_tensor(img_idx)
|
| 139 |
-
img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
|
| 140 |
-
imgs_t.append(img_idx_t)
|
| 141 |
-
|
| 142 |
-
# [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
|
| 143 |
-
img_t = torch.vstack(imgs_t)
|
| 144 |
-
|
| 145 |
-
return img_t
|
| 146 |
-
|
| 147 |
-
def preprocess(
|
| 148 |
-
self,
|
| 149 |
-
images: Union[Image.Image, List[Image.Image]],
|
| 150 |
-
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 151 |
-
**_: str,
|
| 152 |
-
) -> BatchFeature:
|
| 153 |
-
"""
|
| 154 |
-
Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
|
| 155 |
-
explicitly only handle PIL.Image.Image instances for simplicity.
|
| 156 |
-
@param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
|
| 157 |
-
@param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
|
| 158 |
-
@return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
|
| 159 |
-
"""
|
| 160 |
-
if not isinstance(images, list):
|
| 161 |
-
images = [images]
|
| 162 |
-
|
| 163 |
-
# Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
|
| 164 |
-
pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
|
| 165 |
-
|
| 166 |
-
# Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
|
| 167 |
-
return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
|
| 168 |
-
|
| 169 |
-
def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
|
| 170 |
-
return self.preprocess(images, **kwargs)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
|
| 174 |
-
# =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
|
| 175 |
-
class PrismaticProcessor(ProcessorMixin):
|
| 176 |
-
attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
|
| 177 |
-
image_processor_class: str = "AutoImageProcessor"
|
| 178 |
-
tokenizer_class: str = "AutoTokenizer"
|
| 179 |
-
|
| 180 |
-
def __init__(
|
| 181 |
-
self,
|
| 182 |
-
image_processor: Optional[ImageProcessingMixin] = None,
|
| 183 |
-
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 184 |
-
) -> None:
|
| 185 |
-
super().__init__(image_processor, tokenizer)
|
| 186 |
-
|
| 187 |
-
def __call__(
|
| 188 |
-
self,
|
| 189 |
-
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
|
| 190 |
-
images: Union[Image.Image, List[Image.Image]],
|
| 191 |
-
padding: Union[bool, str, PaddingStrategy] = False,
|
| 192 |
-
truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
|
| 193 |
-
max_length: Optional[int] = None,
|
| 194 |
-
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
| 195 |
-
) -> BatchFeature:
|
| 196 |
-
"""
|
| 197 |
-
Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
|
| 198 |
-
forwards images to PrismaticImageProcessor.
|
| 199 |
-
@param text: The (batch) of text to encode; must be a string or list of strings.
|
| 200 |
-
@param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
|
| 201 |
-
@param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
|
| 202 |
-
@param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
|
| 203 |
-
@param max_length: Maximum length (in tokens) to truncate
|
| 204 |
-
@param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
|
| 205 |
-
@return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
|
| 206 |
-
"""
|
| 207 |
-
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
|
| 208 |
-
text_inputs = self.tokenizer(
|
| 209 |
-
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
# [Validate] Need same number of images and text inputs!
|
| 213 |
-
if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
|
| 214 |
-
raise ValueError("Batch is malformed; expected same number of images and text inputs!")
|
| 215 |
-
|
| 216 |
-
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
|
| 217 |
-
|
| 218 |
-
# === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
|
| 219 |
-
def batch_decode(
|
| 220 |
-
self,
|
| 221 |
-
sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
|
| 222 |
-
skip_special_tokens: bool = False,
|
| 223 |
-
clean_up_tokenization_spaces: Optional[bool] = None,
|
| 224 |
-
**kwargs: str,
|
| 225 |
-
) -> List[str]:
|
| 226 |
-
return self.tokenizer.batch_decode(
|
| 227 |
-
sequences=sequences,
|
| 228 |
-
skip_special_tokens=skip_special_tokens,
|
| 229 |
-
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 230 |
-
**kwargs,
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
def decode(
|
| 234 |
-
self,
|
| 235 |
-
token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
|
| 236 |
-
skip_special_tokens: bool = False,
|
| 237 |
-
clean_up_tokenization_spaces: Optional[bool] = None,
|
| 238 |
-
**kwargs: str,
|
| 239 |
-
) -> str:
|
| 240 |
-
return self.tokenizer.decode(
|
| 241 |
-
token_ids=token_ids,
|
| 242 |
-
skip_special_tokens=skip_special_tokens,
|
| 243 |
-
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 244 |
-
**kwargs,
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
@property
|
| 248 |
-
def model_input_names(self) -> List[str]:
|
| 249 |
-
tokenizer_input_names = self.tokenizer.model_input_names
|
| 250 |
-
image_processor_input_names = self.image_processor.model_input_names
|
| 251 |
-
|
| 252 |
-
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/models/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
from .load import available_model_names, available_models, get_model_description, load, load_vla
|
| 2 |
-
from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/models/action_heads.py
DELETED
|
@@ -1,211 +0,0 @@
|
|
| 1 |
-
"""Implementations of various action heads, which serve as alternatives to VLM sequential token prediction."""
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
| 9 |
-
from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class SinusoidalPositionalEncoding(nn.Module):
|
| 13 |
-
"""
|
| 14 |
-
Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps.
|
| 15 |
-
|
| 16 |
-
For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,)
|
| 17 |
-
Then the output would be a batch of 32 timestep embeddings -> shape (32, D)
|
| 18 |
-
|
| 19 |
-
Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
def __init__(self, dim):
|
| 23 |
-
super().__init__()
|
| 24 |
-
self.dim = dim # dimensionality of the positional encoding
|
| 25 |
-
|
| 26 |
-
def forward(self, x):
|
| 27 |
-
# x: (batch_size,)
|
| 28 |
-
device = x.device
|
| 29 |
-
assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}"
|
| 30 |
-
half_dim = self.dim // 2
|
| 31 |
-
exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1) # shape: (D/2,)
|
| 32 |
-
emb = torch.exp(exponent) # shape: (D/2,)
|
| 33 |
-
emb = x[:, None] * emb[None, :] # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2)
|
| 34 |
-
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # shape: (batch_size, D)
|
| 35 |
-
return emb
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class MLPResNetBlock(nn.Module):
|
| 39 |
-
"""One MLP ResNet block with a residual connection."""
|
| 40 |
-
def __init__(self, dim):
|
| 41 |
-
super().__init__()
|
| 42 |
-
self.dim = dim
|
| 43 |
-
self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers
|
| 44 |
-
nn.LayerNorm(dim),
|
| 45 |
-
nn.Linear(dim, dim),
|
| 46 |
-
nn.ReLU(),
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
def forward(self, x):
|
| 50 |
-
# x: (batch_size, hidden_dim)
|
| 51 |
-
# We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as
|
| 52 |
-
# described here: https://arxiv.org/pdf/2002.04745.pdf
|
| 53 |
-
identity = x
|
| 54 |
-
x = self.ffn(x)
|
| 55 |
-
x = x + identity
|
| 56 |
-
return x
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class MLPResNet(nn.Module):
|
| 60 |
-
"""MLP with residual connection blocks."""
|
| 61 |
-
def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
|
| 62 |
-
super().__init__()
|
| 63 |
-
self.layer_norm1 = nn.LayerNorm(input_dim)
|
| 64 |
-
self.fc1 = nn.Linear(input_dim, hidden_dim)
|
| 65 |
-
self.relu = nn.ReLU()
|
| 66 |
-
self.mlp_resnet_blocks = nn.ModuleList()
|
| 67 |
-
for _ in range(num_blocks):
|
| 68 |
-
self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim))
|
| 69 |
-
self.layer_norm2 = nn.LayerNorm(hidden_dim)
|
| 70 |
-
self.fc2 = nn.Linear(hidden_dim, output_dim)
|
| 71 |
-
|
| 72 |
-
def forward(self, x):
|
| 73 |
-
# x: (batch_size, input_dim)
|
| 74 |
-
x = self.layer_norm1(x) # shape: (batch_size, input_dim)
|
| 75 |
-
x = self.fc1(x) # shape: (batch_size, hidden_dim)
|
| 76 |
-
x = self.relu(x) # shape: (batch_size, hidden_dim)
|
| 77 |
-
for block in self.mlp_resnet_blocks:
|
| 78 |
-
x = block(x) # shape: (batch_size, hidden_dim)
|
| 79 |
-
x = self.layer_norm2(x) # shape: (batch_size, hidden_dim)
|
| 80 |
-
x = self.fc2(x) # shape: (batch_size, output_dim)
|
| 81 |
-
return x
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
class L1RegressionActionHead(nn.Module):
|
| 85 |
-
"""Simple MLP-based action head that generates continuous actions via L1 regression."""
|
| 86 |
-
def __init__(
|
| 87 |
-
self,
|
| 88 |
-
input_dim=4096,
|
| 89 |
-
hidden_dim=4096,
|
| 90 |
-
action_dim=7,
|
| 91 |
-
):
|
| 92 |
-
super().__init__()
|
| 93 |
-
self.action_dim = action_dim
|
| 94 |
-
self.model = MLPResNet(
|
| 95 |
-
num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
def predict_action(self, actions_hidden_states):
|
| 99 |
-
# actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence
|
| 100 |
-
# - shape: (batch_size, chunk_len * action_dim, hidden_dim)
|
| 101 |
-
# ground_truth_actions: ground-truth actions
|
| 102 |
-
# - shape: (batch_size, chunk_len, action_dim)
|
| 103 |
-
batch_size = actions_hidden_states.shape[0]
|
| 104 |
-
device = actions_hidden_states.device
|
| 105 |
-
rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)
|
| 106 |
-
action = self.model(rearranged_actions_hidden_states)
|
| 107 |
-
return action
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class NoisePredictionModel(nn.Module):
|
| 111 |
-
"""
|
| 112 |
-
Diffusion noise prediction model that takes an observation embedding (which fuses the
|
| 113 |
-
noisy action, diffusion timestep, and image-language observation embeddings) and
|
| 114 |
-
outputs a noise prediction.
|
| 115 |
-
"""
|
| 116 |
-
|
| 117 |
-
def __init__(
|
| 118 |
-
self,
|
| 119 |
-
transformer_hidden_dim, # Transformer hidden embedding size
|
| 120 |
-
hidden_dim, # MLP hidden size
|
| 121 |
-
action_dim=7, # action dimensionality
|
| 122 |
-
):
|
| 123 |
-
super().__init__()
|
| 124 |
-
self.mlp_resnet = MLPResNet(
|
| 125 |
-
num_blocks=2,
|
| 126 |
-
input_dim=transformer_hidden_dim,
|
| 127 |
-
hidden_dim=hidden_dim,
|
| 128 |
-
output_dim=action_dim,
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
def forward(
|
| 132 |
-
self,
|
| 133 |
-
obs,
|
| 134 |
-
):
|
| 135 |
-
# obs: observation embeddings to condition the generation on
|
| 136 |
-
# - shape: (batch_size, chunk_len, rearranged_hidden_dim=action_dim*hidden_dim)
|
| 137 |
-
#
|
| 138 |
-
# output: predicted noise
|
| 139 |
-
# - shape: (batch_size, action_dim)
|
| 140 |
-
output = self.mlp_resnet(obs)
|
| 141 |
-
return output
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
class DiffusionActionHead(nn.Module):
|
| 145 |
-
"""
|
| 146 |
-
Simple MLP-based action head that generates continuous actions via conditional denoising diffusion process.
|
| 147 |
-
|
| 148 |
-
Loosely inspired by: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py
|
| 149 |
-
"""
|
| 150 |
-
|
| 151 |
-
def __init__(
|
| 152 |
-
self,
|
| 153 |
-
input_dim=4096,
|
| 154 |
-
hidden_dim=4096,
|
| 155 |
-
action_dim=7,
|
| 156 |
-
num_diffusion_steps_train=50,
|
| 157 |
-
):
|
| 158 |
-
super().__init__()
|
| 159 |
-
self.action_dim = action_dim
|
| 160 |
-
self.noise_predictor = NoisePredictionModel(
|
| 161 |
-
transformer_hidden_dim=hidden_dim*ACTION_DIM, hidden_dim=hidden_dim, action_dim=action_dim
|
| 162 |
-
)
|
| 163 |
-
self.num_diffusion_steps_train = num_diffusion_steps_train
|
| 164 |
-
self.noise_scheduler = DDIMScheduler(num_train_timesteps=num_diffusion_steps_train, beta_schedule="squaredcos_cap_v2")
|
| 165 |
-
self.time_encoder = SinusoidalPositionalEncoding(dim=hidden_dim)
|
| 166 |
-
|
| 167 |
-
def sample_noisy_actions(self, ground_truth_actions):
|
| 168 |
-
"""
|
| 169 |
-
Samples noise and applies noise to ground-truth actions to produce noisy actions, which are
|
| 170 |
-
used as input in the noise prediction network. Returns noise, noisy actions, and the
|
| 171 |
-
corresponding diffusion timestep embeddings.
|
| 172 |
-
"""
|
| 173 |
-
# ground_truth_actions: ground-truth actions
|
| 174 |
-
# - shape: (batch_size, chunk_len, action_dim)
|
| 175 |
-
batch_size = ground_truth_actions.shape[0]
|
| 176 |
-
device = ground_truth_actions.device
|
| 177 |
-
# Sample random noise with shape equal to actions, used for closed-form forward diffusion.
|
| 178 |
-
noise = torch.randn(size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), device=device, dtype=ground_truth_actions.dtype) # (B, chunk_len, action_dim)
|
| 179 |
-
# Sample random diffusion timesteps (one for each action in batch).
|
| 180 |
-
timesteps = torch.randint(
|
| 181 |
-
low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device
|
| 182 |
-
)
|
| 183 |
-
# Add noise to clean actions according to the magnitude at each diffusion timestep via
|
| 184 |
-
# closed-form forward diffusion.
|
| 185 |
-
noisy_actions = self.noise_scheduler.add_noise(ground_truth_actions, noise, timesteps) # (B, chunk_len, action_dim)
|
| 186 |
-
|
| 187 |
-
# Get diffusion timestep embeddings as well
|
| 188 |
-
diffusion_timestep_embeddings = self.time_encoder(timesteps).to(noisy_actions.dtype).to(noisy_actions.device) # (B, llm_dim)
|
| 189 |
-
diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
|
| 190 |
-
|
| 191 |
-
return_dict = dict(
|
| 192 |
-
noise=noise,
|
| 193 |
-
noisy_actions=noisy_actions,
|
| 194 |
-
diffusion_timestep_embeddings=diffusion_timestep_embeddings,
|
| 195 |
-
)
|
| 196 |
-
|
| 197 |
-
return return_dict
|
| 198 |
-
|
| 199 |
-
def predict_noise(self, actions_hidden_states):
|
| 200 |
-
"""
|
| 201 |
-
Given a batch of last hidden Transformer layer embeddings (which fuse the vision-language observation embeddings,
|
| 202 |
-
noisy action embeddings, and diffusion timestep embedding), predicts the noise applied to the actions.
|
| 203 |
-
"""
|
| 204 |
-
# actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence
|
| 205 |
-
# - shape: (batch_size, chunk_len * action_dim, hidden_dim)
|
| 206 |
-
batch_size = actions_hidden_states.shape[0]
|
| 207 |
-
device = actions_hidden_states.device
|
| 208 |
-
rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) # (batch_size, chunk_len, action_dim * hidden_dim)
|
| 209 |
-
# Get diffusion model's noise prediction.
|
| 210 |
-
noise_pred = self.noise_predictor(rearranged_actions_hidden_states)
|
| 211 |
-
return noise_pred
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/models/backbones/__init__.py
DELETED
|
File without changes
|
capvector-oft/prismatic/models/backbones/llm/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
from .base_llm import LLMBackbone
|
| 2 |
-
from .llama2 import LLaMa2LLMBackbone
|
| 3 |
-
from .mistral import MistralLLMBackbone
|
| 4 |
-
from .phi import PhiLLMBackbone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/models/backbones/llm/base_llm.py
DELETED
|
@@ -1,223 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
base_llm.py
|
| 3 |
-
|
| 4 |
-
Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class
|
| 5 |
-
methods, utility functions, and initialization logic.
|
| 6 |
-
|
| 7 |
-
We also define the generic HFLLMBackbone class here, providing a default interface for loading any HF
|
| 8 |
-
AutoModelForCausalLM (e.g., LLamaForCausalLM). In general, we make the assumption that any given LLM backbone implements
|
| 9 |
-
the AutoModelForCausalLM API (though we may add Seq2Seq models in the future).
|
| 10 |
-
|
| 11 |
-
We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF
|
| 12 |
-
utilities around different types of decoding/generation strategies.
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
import warnings
|
| 16 |
-
from abc import ABC, abstractmethod
|
| 17 |
-
from functools import partial
|
| 18 |
-
from typing import Callable, List, Optional, Sequence, Type
|
| 19 |
-
|
| 20 |
-
import torch
|
| 21 |
-
import torch.nn as nn
|
| 22 |
-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
| 23 |
-
from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
|
| 24 |
-
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 25 |
-
|
| 26 |
-
from prismatic.models.backbones.llm.prompting import PromptBuilder
|
| 27 |
-
from prismatic.overwatch import initialize_overwatch
|
| 28 |
-
|
| 29 |
-
# Suppress HF Deprecation Warnings
|
| 30 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 31 |
-
|
| 32 |
-
# Initialize Overwatch =>> Wraps `logging.Logger`
|
| 33 |
-
overwatch = initialize_overwatch(__name__)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# === Abstract Base Class for arbitrary HF LLM Backbones ===
|
| 37 |
-
class LLMBackbone(nn.Module, ABC):
|
| 38 |
-
def __init__(self, llm_backbone_id: str) -> None:
|
| 39 |
-
super().__init__()
|
| 40 |
-
self.identifier = llm_backbone_id
|
| 41 |
-
|
| 42 |
-
# Instance attributes for an LLM Backbone
|
| 43 |
-
self.llm: PreTrainedModel = None
|
| 44 |
-
self.tokenizer: PreTrainedTokenizerBase = None
|
| 45 |
-
|
| 46 |
-
def get_tokenizer(self) -> PreTrainedTokenizerBase:
|
| 47 |
-
return self.tokenizer
|
| 48 |
-
|
| 49 |
-
@abstractmethod
|
| 50 |
-
def get_fsdp_wrapping_policy(self) -> Callable: ...
|
| 51 |
-
|
| 52 |
-
@abstractmethod
|
| 53 |
-
def enable_gradient_checkpointing(self) -> None: ...
|
| 54 |
-
|
| 55 |
-
@abstractmethod
|
| 56 |
-
def forward(
|
| 57 |
-
self,
|
| 58 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 59 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 60 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 61 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 62 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 63 |
-
labels: Optional[torch.LongTensor] = None,
|
| 64 |
-
use_cache: Optional[bool] = None,
|
| 65 |
-
output_attentions: Optional[bool] = None,
|
| 66 |
-
output_hidden_states: Optional[bool] = None,
|
| 67 |
-
return_dict: Optional[bool] = None,
|
| 68 |
-
) -> CausalLMOutputWithPast:
|
| 69 |
-
"""Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss"""
|
| 70 |
-
raise NotImplementedError
|
| 71 |
-
|
| 72 |
-
@abstractmethod
|
| 73 |
-
def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ...
|
| 74 |
-
|
| 75 |
-
@property
|
| 76 |
-
@abstractmethod
|
| 77 |
-
def prompt_builder_fn(self) -> Type[PromptBuilder]: ...
|
| 78 |
-
|
| 79 |
-
@property
|
| 80 |
-
@abstractmethod
|
| 81 |
-
def transformer_layer_cls(self) -> Type[nn.Module]: ...
|
| 82 |
-
|
| 83 |
-
@property
|
| 84 |
-
@abstractmethod
|
| 85 |
-
def half_precision_dtype(self) -> torch.dtype: ...
|
| 86 |
-
|
| 87 |
-
@property
|
| 88 |
-
@abstractmethod
|
| 89 |
-
def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ...
|
| 90 |
-
|
| 91 |
-
@property
|
| 92 |
-
def embed_dim(self) -> int:
|
| 93 |
-
return self.llm.config.hidden_size
|
| 94 |
-
|
| 95 |
-
@property
|
| 96 |
-
def pad_token_id(self) -> int:
|
| 97 |
-
return self.tokenizer.pad_token_id
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
# === Abstract Base Class for Arbitrary HF Causal LLMs ===
|
| 101 |
-
class HFCausalLLMBackbone(LLMBackbone, ABC):
|
| 102 |
-
def __init__(
|
| 103 |
-
self,
|
| 104 |
-
llm_backbone_id: str,
|
| 105 |
-
llm_family: str,
|
| 106 |
-
llm_cls: Type[PreTrainedModel],
|
| 107 |
-
hf_hub_path: str,
|
| 108 |
-
llm_max_length: int = 2048,
|
| 109 |
-
hf_token: Optional[str] = None,
|
| 110 |
-
inference_mode: bool = False,
|
| 111 |
-
use_flash_attention_2: bool = False,
|
| 112 |
-
) -> None:
|
| 113 |
-
super().__init__(llm_backbone_id)
|
| 114 |
-
self.llm_family = llm_family
|
| 115 |
-
self.llm_max_length = llm_max_length
|
| 116 |
-
self.inference_mode = inference_mode
|
| 117 |
-
|
| 118 |
-
# Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class!
|
| 119 |
-
# => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details
|
| 120 |
-
if not self.inference_mode:
|
| 121 |
-
overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
|
| 122 |
-
self.llm = llm_cls.from_pretrained(
|
| 123 |
-
hf_hub_path,
|
| 124 |
-
token=hf_token,
|
| 125 |
-
use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False,
|
| 126 |
-
# The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
|
| 127 |
-
do_sample=False,
|
| 128 |
-
temperature=1.0,
|
| 129 |
-
top_p=1.0,
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
# [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights!
|
| 133 |
-
else:
|
| 134 |
-
overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
|
| 135 |
-
llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token)
|
| 136 |
-
self.llm = llm_cls._from_config(llm_config)
|
| 137 |
-
|
| 138 |
-
# Lightweight Handling (with extended explanation) for setting some LLM Parameters
|
| 139 |
-
# => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general)
|
| 140 |
-
#
|
| 141 |
-
# Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958
|
| 142 |
-
self.llm.config.use_cache = False if not self.inference_mode else True
|
| 143 |
-
|
| 144 |
-
# => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters
|
| 145 |
-
# (requires_grad is False), backprop will fail; setting `enable_input_requires_grad()` registers a new
|
| 146 |
-
# forward hook that fixes this =>> also totally safe for the "full finetuning" setting!
|
| 147 |
-
if not self.inference_mode:
|
| 148 |
-
self.llm.enable_input_require_grads()
|
| 149 |
-
|
| 150 |
-
# Load (Fast) Tokenizer
|
| 151 |
-
overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
|
| 152 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 153 |
-
hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right"
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
# Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
|
| 157 |
-
# starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
|
| 158 |
-
# find that adding image patches *after* the BOS leads to much better performance.
|
| 159 |
-
#
|
| 160 |
-
# As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
|
| 161 |
-
# line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
|
| 162 |
-
# override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
|
| 163 |
-
# and VLM `forward()` logic!
|
| 164 |
-
SPECIAL_CASES = {
|
| 165 |
-
# Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
|
| 166 |
-
# =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that
|
| 167 |
-
# this works well with base LLM generation.
|
| 168 |
-
# =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
|
| 169 |
-
"phi-2-3b",
|
| 170 |
-
}
|
| 171 |
-
if self.identifier in SPECIAL_CASES:
|
| 172 |
-
return
|
| 173 |
-
|
| 174 |
-
# Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral!
|
| 175 |
-
assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
|
| 176 |
-
self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
|
| 177 |
-
), (
|
| 178 |
-
f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n"
|
| 179 |
-
"Please read the comment in `base_llm.py` for more information!"
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
def get_fsdp_wrapping_policy(self) -> Callable:
|
| 183 |
-
"""Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
|
| 184 |
-
transformer_block_policy = partial(
|
| 185 |
-
transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
return transformer_block_policy
|
| 189 |
-
|
| 190 |
-
def enable_gradient_checkpointing(self) -> None:
|
| 191 |
-
"""Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PretrainedModel`."""
|
| 192 |
-
self.llm.gradient_checkpointing_enable()
|
| 193 |
-
|
| 194 |
-
def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
| 195 |
-
return self.llm.get_input_embeddings()(input_ids)
|
| 196 |
-
|
| 197 |
-
# [Contract] Should match the `forward` call of the underlying `llm` instance!
|
| 198 |
-
def forward(
|
| 199 |
-
self,
|
| 200 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 201 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 202 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 203 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 204 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 205 |
-
labels: Optional[torch.LongTensor] = None,
|
| 206 |
-
use_cache: Optional[bool] = None,
|
| 207 |
-
output_attentions: Optional[bool] = None,
|
| 208 |
-
output_hidden_states: Optional[bool] = None,
|
| 209 |
-
return_dict: Optional[bool] = None,
|
| 210 |
-
) -> CausalLMOutputWithPast:
|
| 211 |
-
output: CausalLMOutputWithPast = self.llm(
|
| 212 |
-
input_ids=input_ids,
|
| 213 |
-
attention_mask=attention_mask,
|
| 214 |
-
position_ids=position_ids,
|
| 215 |
-
past_key_values=past_key_values,
|
| 216 |
-
inputs_embeds=inputs_embeds,
|
| 217 |
-
labels=labels,
|
| 218 |
-
use_cache=use_cache,
|
| 219 |
-
output_attentions=output_attentions,
|
| 220 |
-
output_hidden_states=output_hidden_states,
|
| 221 |
-
return_dict=return_dict,
|
| 222 |
-
)
|
| 223 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/models/backbones/llm/llama2.py
DELETED
|
@@ -1,102 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
llama2.py
|
| 3 |
-
|
| 4 |
-
Class definition for all LLMs derived from LlamaForCausalLM.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from typing import Optional, Sequence, Type
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torch import nn as nn
|
| 11 |
-
from transformers import LlamaForCausalLM
|
| 12 |
-
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
| 13 |
-
|
| 14 |
-
from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
|
| 15 |
-
from prismatic.models.backbones.llm.prompting import (
|
| 16 |
-
LLaMa2ChatPromptBuilder,
|
| 17 |
-
PromptBuilder,
|
| 18 |
-
PurePromptBuilder,
|
| 19 |
-
VicunaV15ChatPromptBuilder,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
# Registry =>> Support LLaMa-2 Models (from HF Transformers)
|
| 23 |
-
# fmt: off
|
| 24 |
-
LLAMA2_MODELS = {
|
| 25 |
-
# === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models ===
|
| 26 |
-
"llama2-7b-pure": {
|
| 27 |
-
"llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf"
|
| 28 |
-
},
|
| 29 |
-
|
| 30 |
-
"llama2-13b-pure": {
|
| 31 |
-
"llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf"
|
| 32 |
-
},
|
| 33 |
-
|
| 34 |
-
# === Meta LLaMa-2 Chat Models ===
|
| 35 |
-
"llama2-7b-chat": {
|
| 36 |
-
"llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf"
|
| 37 |
-
},
|
| 38 |
-
|
| 39 |
-
"llama2-13b-chat": {
|
| 40 |
-
"llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf"
|
| 41 |
-
},
|
| 42 |
-
|
| 43 |
-
# === Vicuna v1.5 Chat Models ===
|
| 44 |
-
"vicuna-v15-7b": {
|
| 45 |
-
"llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5"
|
| 46 |
-
},
|
| 47 |
-
|
| 48 |
-
"vicuna-v15-13b": {
|
| 49 |
-
"llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5"
|
| 50 |
-
},
|
| 51 |
-
}
|
| 52 |
-
# fmt: on
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class LLaMa2LLMBackbone(HFCausalLLMBackbone):
|
| 56 |
-
def __init__(
|
| 57 |
-
self,
|
| 58 |
-
llm_backbone_id: str,
|
| 59 |
-
llm_max_length: int = 2048,
|
| 60 |
-
hf_token: Optional[str] = None,
|
| 61 |
-
inference_mode: bool = False,
|
| 62 |
-
use_flash_attention_2: bool = True,
|
| 63 |
-
) -> None:
|
| 64 |
-
super().__init__(
|
| 65 |
-
llm_backbone_id,
|
| 66 |
-
llm_max_length=llm_max_length,
|
| 67 |
-
hf_token=hf_token,
|
| 68 |
-
inference_mode=inference_mode,
|
| 69 |
-
use_flash_attention_2=use_flash_attention_2,
|
| 70 |
-
**LLAMA2_MODELS[llm_backbone_id],
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
# [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize)
|
| 74 |
-
self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 75 |
-
self.llm.config.pad_token_id = self.tokenizer.pad_token_id
|
| 76 |
-
self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
|
| 77 |
-
|
| 78 |
-
@property
|
| 79 |
-
def prompt_builder_fn(self) -> Type[PromptBuilder]:
|
| 80 |
-
if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"):
|
| 81 |
-
return PurePromptBuilder
|
| 82 |
-
|
| 83 |
-
elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"):
|
| 84 |
-
return LLaMa2ChatPromptBuilder
|
| 85 |
-
|
| 86 |
-
elif self.identifier.startswith("vicuna"):
|
| 87 |
-
return VicunaV15ChatPromptBuilder
|
| 88 |
-
|
| 89 |
-
raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
|
| 90 |
-
|
| 91 |
-
@property
|
| 92 |
-
def transformer_layer_cls(self) -> Type[nn.Module]:
|
| 93 |
-
return LlamaDecoderLayer
|
| 94 |
-
|
| 95 |
-
@property
|
| 96 |
-
def half_precision_dtype(self) -> torch.dtype:
|
| 97 |
-
"""LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2."""
|
| 98 |
-
return torch.bfloat16
|
| 99 |
-
|
| 100 |
-
@property
|
| 101 |
-
def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
|
| 102 |
-
return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
capvector-oft/prismatic/models/backbones/llm/mistral.py
DELETED
|
@@ -1,72 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
mistral.py
|
| 3 |
-
|
| 4 |
-
Class definition for all LLMs derived from MistralForCausalLM.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from typing import Optional, Type
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torch import nn as nn
|
| 11 |
-
from transformers import MistralForCausalLM
|
| 12 |
-
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
|
| 13 |
-
|
| 14 |
-
from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
|
| 15 |
-
from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder
|
| 16 |
-
|
| 17 |
-
# Registry =>> Support Mistral Models (from HF Transformers)
|
| 18 |
-
# fmt: off
|
| 19 |
-
MISTRAL_MODELS = {
|
| 20 |
-
# === Base Mistral v0.1 ===
|
| 21 |
-
"mistral-v0.1-7b-pure": {
|
| 22 |
-
"llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
|
| 23 |
-
},
|
| 24 |
-
|
| 25 |
-
# === Mistral Instruct v0.1 ===
|
| 26 |
-
"mistral-v0.1-7b-instruct": {
|
| 27 |
-
"llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
|
| 28 |
-
}
|
| 29 |
-
}
|
| 30 |
-
# fmt: on
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class MistralLLMBackbone(HFCausalLLMBackbone):
|
| 34 |
-
def __init__(
|
| 35 |
-
self,
|
| 36 |
-
llm_backbone_id: str,
|
| 37 |
-
llm_max_length: int = 2048,
|
| 38 |
-
hf_token: Optional[str] = None,
|
| 39 |
-
inference_mode: bool = False,
|
| 40 |
-
use_flash_attention_2: bool = True,
|
| 41 |
-
) -> None:
|
| 42 |
-
super().__init__(
|
| 43 |
-
llm_backbone_id,
|
| 44 |
-
llm_max_length=llm_max_length,
|
| 45 |
-
hf_token=hf_token,
|
| 46 |
-
inference_mode=inference_mode,
|
| 47 |
-
use_flash_attention_2=use_flash_attention_2,
|
| 48 |
-
**MISTRAL_MODELS[llm_backbone_id],
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
# [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
|
| 52 |
-
self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 53 |
-
self.llm.config.pad_token_id = self.tokenizer.pad_token_id
|
| 54 |
-
self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
|
| 55 |
-
|
| 56 |
-
@property
|
| 57 |
-
def prompt_builder_fn(self) -> Type[PromptBuilder]:
|
| 58 |
-
if self.identifier.endswith("-pure"):
|
| 59 |
-
return PurePromptBuilder
|
| 60 |
-
|
| 61 |
-
elif self.identifier.endswith("-instruct"):
|
| 62 |
-
return MistralInstructPromptBuilder
|
| 63 |
-
|
| 64 |
-
raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
|
| 65 |
-
|
| 66 |
-
@property
|
| 67 |
-
def transformer_layer_cls(self) -> Type[nn.Module]:
|
| 68 |
-
return MistralDecoderLayer
|
| 69 |
-
|
| 70 |
-
@property
|
| 71 |
-
def half_precision_dtype(self) -> torch.dtype:
|
| 72 |
-
return torch.bfloat16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|