diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index a6344aac8c09253b3b630fb776ae94478aa0275b..0000000000000000000000000000000000000000 --- a/.gitattributes +++ /dev/null @@ -1,35 +0,0 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md deleted file mode 100644 index e75904ef5bf65a72a86f9aad92c69b6f5cfed495..0000000000000000000000000000000000000000 --- a/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# CapVector: Learning Transferable Capability Vectors in Parametric Space for Vision-Language-Action Models - -
- -[![Paper](https://img.shields.io/badge/Paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](http://arxiv.org/abs/) [![Page](https://img.shields.io/badge/Project--Page-blue?style=for-the-badge&logo=homepage&logoColor=white)](https://capvector.github.io/) [![Hugging Face Collection](https://img.shields.io/badge/Models-fcd022?style=for-the-badge&logo=huggingface&logoColor=white)](https://huggingface.co/haofuly/capvector_models_collection) - -
- -CapVector is a training recipe for vision-language-action (VLA) models that extracts a transferable capability vector from the parameter difference between auxiliary-objective SFT methods and standard SFT methods. This vector is merged into a pretrained VLA to form a stronger initialization, and downstream adaptation uses standard SFT with a lightweight orthogonal regularization loss to preserve the injected capability. - - -## 🌟 Key Features -- **Efficient downstream adaptation**: CapVector recovers much of the benefit of auxiliary-objective SFT methods, while keeping the downstream overhead close to standard SFT. -- **Versatility**: CapVector fits for OpenVLA-based, OpenPi-based, and StarVLA-based backbones. -- **Generalization**: CapVector is designed to transfer across tasks, environments, and robot embodiments. - - -## 🚀 Get Started - -This repository provides two implementation paths: -- [`capvector-oft/`](./capvector-oft) based implementation -- [`capvector-pi05/`](./capvector-pi05) based implementation. - -Choose the subdirectory that matches your base model and training stack. Follow the subproject README for environment setup, data preparation, training, and inference. - -[`capvector-pi05/`](./capvector-pi05) provides the capability vector extraction and merging scripts. - - -## 🌏 Contact -For further discussion and collaboration, please feel free to contact us via Email and WeChat: - -| Author | Email | WeChat | -|:---:|:---:|:---:| -| Wenxuan Song | songwenxuan0115@gmail.com | swx0757 | - - -## ❤️ Acknowledgments - -CapVector builds on and interfaces with several excellent open-source projects, including: - -- [OpenVLA-OFT](https://github.com/moojink/openvla-oft) -- [OpenPI](https://github.com/Physical-Intelligence/openpi) - - -## 🖊 Citation - -If you find this work useful, please cite: - -```bibtex -@article{song2026capvector, - title = {CapVector: Learning Transferable Capability Vectors in Parametric Space for Vision-Language-Action Models}, - author = {Song, Wenxuan and Zhao, Han and Li, Fuhao and Zhou, Ziyang and Wang, Xi and Lyu, Jing and Ding, Pengxiang and Wang, Yan and Wang, Donglin and Li, Haoang}, - journal = {Preprint}, - year = {2026} -} -``` - diff --git a/capvector-oft/.pre-commit-config.yaml b/capvector-oft/.pre-commit-config.yaml deleted file mode 100644 index 84c9d149ad54e2c25ba3a07a8435ed2f4c5d09e9..0000000000000000000000000000000000000000 --- a/capvector-oft/.pre-commit-config.yaml +++ /dev/null @@ -1,27 +0,0 @@ -# See https://pre-commit.com for more information -# See https://pre-commit.com/hooks.html for more hooks -exclude: ".git" - -repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.2 - hooks: - - id: ruff - args: [ --fix, --exit-non-zero-on-fix ] - - - repo: https://github.com/psf/black - rev: 24.2.0 - hooks: - - id: black - - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 - hooks: - - id: check-added-large-files - - id: check-ast - - id: check-case-conflict - - id: check-merge-conflict - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace diff --git a/capvector-oft/ALOHA.md b/capvector-oft/ALOHA.md deleted file mode 100644 index f4aea9f5b765a0e4557045ef07eb58aeaab0d132..0000000000000000000000000000000000000000 --- a/capvector-oft/ALOHA.md +++ /dev/null @@ -1,157 +0,0 @@ -# OpenVLA-OFT+ in Real-World ALOHA Robot Tasks - -## Relevant Files - -Evaluation -* `experiments/robot/aloha/`: ALOHA training and eval files - * `run_aloha_eval.py`: ALOHA eval script (CLIENT SIDE; 
see "SERVER SIDE" below) - * `aloha_utils.py`: ALOHA eval utils - * Other ALOHA robot environment files copied from the original [ALOHA GitHub repo](https://github.com/tonyzhaozh/aloha): - * `constants.py` - * `real_env.py` - * `robot_utils.py` -* `experiments/robot/`: General eval utils files - * `openvla_utils.py`: OpenVLA-specific eval utils - * `robot_utils.py`: Other eval utils -* `vla-scripts/deploy.py`: VLA server deploy script (SERVER SIDE) - -Note: Unlike the LIBERO evaluation setup, we use a server-client interface here. This is particularly useful if the user's machine which commands the robot does not have access to a local GPU with sufficient specs to run the fine-tuned VLA policies. - -Training -* `experiments/robot/aloha/`: ALOHA training and eval files - * `preprocess_split_aloha_data.py`: ALOHA data preprocessing script -* `vla-scripts/finetune.py`: VLA fine-tuning script - -## Setup - -Set up a conda environment for training policies and deploying them on the VLA server (see instructions in [SETUP.md](SETUP.md)). - -## Fine-Tuning on ALOHA Robot Data - -We assume that you have collected a set of expert demonstrations on the ALOHA robot already. - -First, use our `preprocess_split_aloha_data.py` script to preprocess the raw ALOHA dataset: downsize images from 480x640 to 256x256 and split into training and validation sets. Below are examples for the `put X into pot` task in our paper (which has 3 possible target objects, 1 per episode): - -```bash -python experiments/robot/aloha/preprocess_split_aloha_data.py \ - --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \ - --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ - --percent_val 0.05 -python experiments/robot/aloha/preprocess_split_aloha_data.py \ - --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \ - --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ - --percent_val 0.05 -python experiments/robot/aloha/preprocess_split_aloha_data.py \ - --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \ - --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ - --percent_val 0.05 -``` - -Then, convert the preprocessed ALOHA datasets into a single RLDS dataset that is compatible with OpenVLA fine-tuning. This process is the same as in the original OpenVLA repo. See instructions for converting to RLDS [here](https://github.com/moojink/rlds_dataset_builder) (a sample ALOHA preprocessed-to-RLDS conversion script is available [here](https://github.com/moojink/rlds_dataset_builder/blob/main/aloha1_put_X_into_pot_300_demos/aloha1_put_X_into_pot_300_demos_dataset_builder.py); this script converts the three preprocessed datasets above into one unified RLDS dataset, with train/val splits). - -After converting to RLDS, register the dataset (which, for the example task above, would be called `aloha1_put_X_into_pot_300_demos`) with our dataloader by adding an entry for it in `configs.py` ([here](prismatic/vla/datasets/rlds/oxe/configs.py#L680)), `transforms.py` ([here](prismatic/vla/datasets/rlds/oxe/transforms.py#L928)), and `mixtures.py` ([here](prismatic/vla/datasets/rlds/oxe/mixtures.py#L216)). For reference, in each of these files, there are sample entries for the ALOHA datasets that we used in our paper. - -Before fine-tuning, set the desired ALOHA action chunk size in [`prismatic/vla/constants.py`](prismatic/vla/constants.py) (see `NUM_ACTIONS_CHUNK` in `ALOHA_CONSTANTS`). 
We set it to 25 by default because we used a control frequency of 25 Hz in our ALOHA setup to reduce storage costs and training time (while still maintaining smoothness in the robot's motions). If you use 50 Hz, we recommend setting `NUM_ACTIONS_CHUNK` to `50`. In general, 1-second-long action chunks are a good default. Do NOT modify `ACTION_PROPRIO_NORMALIZATION_TYPE`: Since the ALOHA robot action space is absolute joint angles, we do not want to use a normalization scheme that clips outlier values (like the Q1-Q99 normalization we used with the relative end-effector pose actions for LIBERO), since that would prevent the model from outputting certain robot joint angles that are crucial for solving the task. - -Now begin fine-tuning! Below is a sample command to fine-tune OpenVLA using our OFT+ recipe on the `put X into pot` task above ("+" in "OFT+" means FiLM is included for enhanced language grounding). Replace `X` in the first line with the number of GPUs available to you. - -```bash -torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \ - --vla_path openvla/openvla-7b \ - --data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \ - --dataset_name aloha1_put_X_into_pot_300_demos \ - --run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \ - --use_l1_regression True \ - --use_diffusion False \ - --use_film True \ - --num_images_in_input 3 \ - --use_proprio True \ - --batch_size 4 \ - --learning_rate 5e-4 \ - --num_steps_before_decay 50000 \ - --max_steps 100005 \ - --use_val_set True \ - --val_freq 10000 \ - --save_freq 10000 \ - --save_latest_checkpoint_only False \ - --image_aug True \ - --lora_rank 32 \ - --wandb_entity "YOUR_WANDB_ENTITY" \ - --wandb_project "YOUR_WANDB_PROJECT" \ - --run_id_note parallel_dec--25_acts_chunk--continuous_acts--L1_regression--3rd_person_img--left_right_wrist_imgs--proprio_state--film -``` - -The above training command should reproduce our OpenVLA-OFT+ results on the `put X into pot` task if `X = 8` and the 100K step checkpoint is evaluated. It will fine-tune OpenVLA using 3 input images (1 third-person image + 2 wrist camera images). Note that we use learning rate decay after a certain point (50K steps in the command above) since doing so speeds up training convergence (in our experience, the training L1 loss spikes down at that point). - -Best practices for fine-tuning: -* In general, we recommend fine-tuning until training L1 loss goes below 0.01 and starts to plateau. - * One way to achieve this is to fine-tune using our default learning rate of `5e-4` until the loss starts to decrease very slowly, and then decay the learning rate by 10x to `5e-5` (which should make the loss spike down) and train until the training L1 loss finally plateaus. -* Depending on your dataset size, you may need to adjust some hyperparameters. For example, if you use a large dataset with over 300 demos, you may need to decay the learning rate later and train for longer for best performance. Decaying too early can lead to a suboptimal policy. -* If your task does not require good language grounding (e.g., if there is only one language instruction), FiLM is not necessary; consider setting `--use_film False` to train fewer model parameters. -* Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100).
You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you). - -If you run into any issues, please open a new GitHub issue. - -## Launching ALOHA Robot Evaluations - -In the primary conda environment (`openvla-oft`) which you will use to launch the VLA server, install a few packages for the server-client interface: - -```bash -conda activate openvla-oft -pip install uvicorn fastapi json-numpy -``` - -On the machine that you will use to command the robot, set up a second conda environment that will be used to run the robot environment, query the VLA server, and execute actions in the environment: - -```bash -# Create and activate client conda environment -conda create -n openvla-oft-aloha python=3.10 -y -conda activate openvla-oft-aloha - -# Install PyTorch -# Use a command specific to your machine: https://pytorch.org/get-started/locally/ -pip3 install torch torchvision torchaudio - -# Clone openvla-oft repo and pip install to download dependencies -git clone https://github.com/moojink/openvla-oft.git -cd openvla-oft -pip install -e . - -# Install packages needed for the ALOHA robot environment -pip install -r experiments/robot/aloha/requirements_aloha.txt -``` - -Launch the VLA server on the machine that has the GPU you will use to run model inference (using the `openvla-oft` conda environment). Below is a sample command for this (change as needed): - -```bash -python vla-scripts/deploy.py \ - --pretrained_checkpoint /PATH/TO/FINETUNED/MODEL/CHECKPOINT/DIR/ \ - --use_l1_regression True \ - --use_film True \ - --num_images_in_input 3 \ - --use_proprio True \ - --center_crop True \ - --unnorm_key aloha1_put_X_into_pot_300_demos -``` - -Then, run the ALOHA evaluation script. Specify the VLA server URL or IP address in the `vla_server_url` argument, and fill in the remaining placeholders for the number of planned rollouts and the maximum number of steps per rollout. Below is a sample command: - -```bash -python experiments/robot/aloha/run_aloha_eval.py \ - --center_crop True \ - --num_open_loop_steps 25 \ - --use_vla_server True \ - --vla_server_url <VLA_SERVER_URL_OR_IP> \ - --num_rollouts_planned <NUM_ROLLOUTS> \ - --max_steps <MAX_STEPS_PER_ROLLOUT> -``` - -If you run into any issues, please open a new GitHub issue. - -## Troubleshooting Tips - -* Tip #1: If you run into a ROS error such as `ImportError: /lib/x86_64-linux-gnu/libp11-kit.so.0: undefined symbol: ffi_type_pointer, version LIBFFI_BASE_7.0`, try running the following command in your client conda environment (`openvla-oft-aloha`): - - ``` - conda install -c conda-forge libffi - ``` diff --git a/capvector-oft/LIBERO.md b/capvector-oft/LIBERO.md deleted file mode 100644 index aaf6aadd69ed6eb98317c61e2402800689c5c2f1..0000000000000000000000000000000000000000 --- a/capvector-oft/LIBERO.md +++ /dev/null @@ -1,130 +0,0 @@ -# OpenVLA-OFT in the LIBERO Simulation Benchmark - -## Relevant Files - -Evaluation -* `experiments/robot/libero/`: LIBERO eval files - * `run_libero_eval.py`: LIBERO eval script - * `libero_utils.py`: LIBERO eval utils -* `experiments/robot/`: General eval utils files - * `openvla_utils.py`: OpenVLA-specific eval utils - * `robot_utils.py`: Other eval utils - -Training -* `vla-scripts/finetune.py`: VLA fine-tuning script - - -## Setup - -Set up a conda environment (see instructions in [SETUP.md](SETUP.md)).
- -Clone and install the [LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) and required packages: - -```bash -git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git -pip install -e LIBERO -pip install -r experiments/robot/libero/libero_requirements.txt # From openvla-oft base dir -``` - -(Optional, if you plan to launch training) To download the [LIBERO datasets](https://huggingface.co/datasets/openvla/modified_libero_rlds) that we used in our fine-tuning -experiments, run the command below. This will download the LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, -and LIBERO-10 datasets in RLDS data format (~10 GB total). You can use these to fine-tune OpenVLA or -train other methods. This step is optional since we provide pretrained OpenVLA-OFT checkpoints below. -Note that these are the same datasets used in the original OpenVLA project. If needed, see details on how to download the original non-RLDS datasets [here](https://github.com/openvla/openvla?tab=readme-ov-file#libero-setup). -```bash -git clone git@hf.co:datasets/openvla/modified_libero_rlds -``` - -## Launching LIBERO Evaluations - -We fine-tuned OpenVLA via LoRA (r=32) with our OFT recipe on four LIBERO task suites: LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, and LIBERO-10 (also called LIBERO-Long). -In the initial version of our paper, we trained one checkpoint for each LIBERO task suite independently. In an updated version of the paper, we conducted an additional experiment in which we trained a single policy on all four task suites combined (results for this are available in the Additional Experiments section in the Appendix). Overall, the results for the task-specific policies and the combined policy are comparable: 97.1% vs. 96.8% average success rate across the four suites, respectively. - -Below are the four independently trained OpenVLA-OFT checkpoints for LIBERO: -* [moojink/openvla-7b-oft-finetuned-libero-spatial](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial) -* [moojink/openvla-7b-oft-finetuned-libero-object](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-object) -* [moojink/openvla-7b-oft-finetuned-libero-goal](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-goal) -* [moojink/openvla-7b-oft-finetuned-libero-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-10) - -Below is the OpenVLA-OFT checkpoint trained on all four task suites combined: -* [moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10) - -To start evaluations with one of the independently trained checkpoints, run one of the commands below. Each will automatically download the appropriate checkpoint listed above. You can set the `TRANSFORMERS_CACHE` and `HF_HOME` environment variable to change where the checkpoint files get cached. 
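If you would like to pre-fetch a checkpoint into a specific cache directory before running the commands below (e.g., on a machine with a non-default cache location), here is a minimal sketch using `huggingface_hub` (the cache path is a placeholder):

```python
# Optionally pre-download an eval checkpoint into a chosen cache directory;
# the eval commands below will otherwise download checkpoints automatically.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="moojink/openvla-7b-oft-finetuned-libero-spatial",
    cache_dir="/path/to/your/hf_cache",  # placeholder; any writable directory
)
```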
- -```bash -# Launch LIBERO-Spatial evals -python experiments/robot/libero/run_libero_eval.py \ - --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-spatial \ - --task_suite_name libero_spatial - -# Launch LIBERO-Object evals -python experiments/robot/libero/run_libero_eval.py \ - --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-object \ - --task_suite_name libero_object - -# Launch LIBERO-Goal evals -python experiments/robot/libero/run_libero_eval.py \ - --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-goal \ - --task_suite_name libero_goal - -# Launch LIBERO-10 (LIBERO-Long) evals -python experiments/robot/libero/run_libero_eval.py \ - --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-10 \ - --task_suite_name libero_10 -``` - -To evaluate the policy trained on all four task suites together, simply swap out the `--pretrained_checkpoint` in the commands above with `moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10`. - -Notes: -* The evaluation script will run 500 trials by default (10 tasks x 50 episodes each). You can modify the number of - trials per task by setting `--num_trials_per_task`. You can also change the random seed via `--seed`. There are - other arguments in the script; we set them to the default values that work with the OpenVLA-OFT checkpoints above. -* **NOTE: Setting `--center_crop True` is important** because we fine-tuned OpenVLA with random crop augmentations - (we took a random crop with 90% area in every training sample, so at test time we simply take the center 90% crop). -* The evaluation script logs results locally. You can also log results in Weights & Biases - by setting `--use_wandb True` and specifying `--wandb_project ` and `--wandb_entity `. -* The results reported in our paper were obtained using **Python 3.10.14, PyTorch 2.2.0, and our - [custom transformers v4.40.1 fork](https://github.com/moojink/transformers-openvla-oft.git)** - on an **NVIDIA A100 GPU**, averaged over three random seeds. Please stick to these package versions if possible. - Note that results may vary slightly if you use a different GPU than the A100. If the discrepancy is large, - please post a GitHub issue, and we will look into it. - -## Fine-Tuning on LIBERO Datasets - -First, download the LIBERO datasets as mentioned above in the Setup section above: `libero_spatial_no_noops`, `libero_object_no_noops`, `libero_goal_no_noops`, `libero_10_no_noops`. (`"_no_noops"` stands for no no-op actions, i.e., training samples with near-zero actions are filtered out). - -Then, launch the fine-tuning script with the OFT configuration below, replacing `X` in the first line with the number of GPUs. The command below launches fine-tuning on LIBERO-Spatial with the hyperparameters that we used in our paper. Here, batch size 8 per GPU will require ~62 GB VRAM, and batch size 1 per GPU will require ~25 GB VRAM. 
- -```bash -torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \ - --vla_path openvla/openvla-7b \ - --data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \ - --dataset_name libero_spatial_no_noops \ - --run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \ - --use_l1_regression True \ - --use_diffusion False \ - --use_film False \ - --num_images_in_input 2 \ - --use_proprio True \ - --batch_size 8 \ - --learning_rate 5e-4 \ - --num_steps_before_decay 100000 \ - --max_steps 150005 \ - --save_freq 10000 \ - --save_latest_checkpoint_only False \ - --image_aug True \ - --lora_rank 32 \ - --wandb_entity "YOUR_WANDB_ENTITY" \ - --wandb_project "YOUR_WANDB_PROJECT" \ - --run_id_note parallel_dec--8_acts_chunk--continuous_acts--L1_regression--3rd_person_img--wrist_img--proprio_state -``` - -The above training command should reproduce our OpenVLA-OFT results if `X = 8` and the 150K step checkpoint is evaluated. - -You can replace `libero_spatial_no_noops` with `libero_object_no_noops`, `libero_goal_no_noops`, or `libero_10_no_noops`. You can also modify other args — e.g., if you want to train with just one input image from the third-person camera and disable proprio state input, you can set `--num_images_in_input 1` and `--use_proprio False`. - -In general, we recommend fine-tuning until training L1 loss goes below 0.01 and starts to plateau (with the above configuration, it should reach ~0.006 L1 loss on LIBERO-Spatial after 150K gradient steps with 10x LR decay after 100K steps). However, for LIBERO-Goal only, we found that the 50K checkpoint (which was at ~0.02 L1 loss) performed best for unknown reasons. For all other task suites though, we found that the 150K checkpoint performed best. - -Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you). - -If you run into any issues, please open a new GitHub issue. If you do not receive a response within 2 business days, please email Moo Jin Kim (moojink@cs.stanford.edu) to bring the issue to his attention. diff --git a/capvector-oft/LICENSE b/capvector-oft/LICENSE deleted file mode 100644 index 852706a85173e194ead678a927cef6c362d081b6..0000000000000000000000000000000000000000 --- a/capvector-oft/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 Moo Jin Kim, Chelsea Finn, Percy Liang. 
- -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/capvector-oft/SETUP.md b/capvector-oft/SETUP.md deleted file mode 100644 index 89e512d0dfbf26e45509d18fe383f54a889ea613..0000000000000000000000000000000000000000 --- a/capvector-oft/SETUP.md +++ /dev/null @@ -1,24 +0,0 @@ -# Setup Instructions - -## Set Up Conda Environment - -```bash -# Create and activate conda environment -conda create -n capvector-openvla-oft python=3.10 -y -conda activate capvector-openvla-oft - -# Install PyTorch -# Use a command specific to your machine: https://pytorch.org/get-started/locally/ -pip3 install torch torchvision torchaudio - -# Clone the CapVector repo and pip install the capvector-oft subproject to download dependencies -git clone https://github.com/Songwxuan/CapVector -cd CapVector/capvector-oft -pip install -e . - -# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention) -# =>> If you run into difficulty, try `pip cache remove flash_attn` first -pip install packaging ninja -ninja --version; echo $?
# Verify Ninja --> should return exit code "0" -pip install "flash-attn==2.5.5" --no-build-isolation -``` \ No newline at end of file diff --git a/capvector-oft/capvector/.gitignore b/capvector-oft/capvector/.gitignore deleted file mode 100644 index d49c7727561cbee22189ff80c618be1061d6e2fe..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -bin/ -draw_pic/ -feature_vector_ckpt/ -figure/ -id_extrapolation/ -id_interpolation/ -initialized_pt_vla/ -lora_diff/ \ No newline at end of file diff --git a/capvector-oft/capvector/compute_lora_diff.py b/capvector-oft/capvector/compute_lora_diff.py deleted file mode 100644 index cb06b6671594c22341405e39b023e72eb94566fa..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/compute_lora_diff.py +++ /dev/null @@ -1,35 +0,0 @@ -from safetensors.torch import load_file, save_file -import torch -import argparse - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--base", required=True) - parser.add_argument("--target", required=True) - parser.add_argument("--out", default="lora_diff.safetensors") - args = parser.parse_args() - - base = load_file(args.base) - target = load_file(args.target) - - diff = {} - - print("=== Key Comparison ===") - only_in_base = set(base) - set(target) - only_in_target = set(target) - set(base) - - print("Only in base:", list(only_in_base)[:10]) - print("Only in target:", list(only_in_target)[:10]) - - for k in target: - if k in base: - diff[k] = target[k] - base[k] - else: - # new parameters are directly retained - diff[k] = target[k].clone() - - save_file(diff, args.out) - print(f"\nSaved diff to: {args.out}") - -if __name__ == "__main__": - main() diff --git a/capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh b/capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh deleted file mode 100644 index 8cbd7ae132a7f853a566a2929761756a3a33596a..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/compute_lora_shell/compute_lora_diff.sh +++ /dev/null @@ -1,8 +0,0 @@ -BASE_ADAPTER="checkpoints/reference_models/openvla_oft_libero_spatial/lora_adapter/adapter_model.safetensors" -TARGET_ADAPTER="checkpoints/task_models/SF_spatial/lora_adapter/adapter_model.safetensors" -OUTPUT_DIFF="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors" - -python compute_lora_diff.py \ - --base "$BASE_ADAPTER" \ - --target "$TARGET_ADAPTER" \ - --out "$OUTPUT_DIFF" diff --git a/capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh b/capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh deleted file mode 100644 index ac6195836a0af9e7f1d08b95577acaf94c198927..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/initialized_interpolate_shell/get_vector_robotwin.sh +++ /dev/null @@ -1,26 +0,0 @@ -TASK=bigbin_pot_microwave_qrcode_bowlsthree # Customize for your task -VERSION=53 -PT_CKPT="checkpoints/openvla_base" -TASK_MODEL_CHECKPOINT="checkpoints/task_models/v106.1" -REFERENCE_MODEL_CHECKPOINT="checkpoints/reference_models/v106.0" -VECTOR_SAVE_PATH="checkpoints/feature_vectors/feature_vector_with_SF_${TASK}_v${VERSION}.pth" -INITIALIZED_PT_VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_${TASK}_v${VERSION}" -TASK_SUITE_NAME="ALOHA_${TASK}" - -python interpolate_robotwin.py \ - --pretrained_checkpoint "$TASK_MODEL_CHECKPOINT" \ - --original_pretrained_checkpoint "$REFERENCE_MODEL_CHECKPOINT" \ - --vector_save_path 
"$VECTOR_SAVE_PATH" \ - --initialized_pt_vla_path $INITIALIZED_PT_VLA_PATH \ - --pt_ckpt $PT_CKPT\ - --feature_vector_weight 1.1 \ - --task_suite_name $TASK_SUITE_NAME - -#the code below is used to transplant the model parameters except for the vla backbone, such as processor and tokenizer, so that the initialized model is complete - -rsync -av \ - --ignore-existing \ - --exclude='*.safetensors' \ - --exclude='*.back.*' \ - $PT_CKPT/ \ - $INITIALIZED_PT_VLA_PATH/ diff --git a/capvector-oft/capvector/interpolate.py b/capvector-oft/capvector/interpolate.py deleted file mode 100644 index 50a2e55a0da4b64b047161ddbd9aa26376531f7e..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/interpolate.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model. -""" - - -import os -import json -import logging - -import sys -from collections import deque -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Optional, Union -from PIL import Image - -import draccus -import numpy as np -from tqdm import tqdm -import torch -import copy - -import wandb - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.append(str(REPO_ROOT)) -from experiments.robot.openvla_utils import ( - get_action_head, - get_noisy_action_projector, - get_processor, - get_proprio_projector, - resize_image_for_policy, -) -from experiments.robot.robot_utils import ( - DATE_TIME, - get_action, - get_image_resize_size, - get_model, - invert_gripper_action, - normalize_gripper_action, - set_seed_everywhere, -) -from experiments.robot.libero.run_libero_eval import check_unnorm_key -from prismatic.vla.constants import NUM_ACTIONS_CHUNK - - -# Set up logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - - -@dataclass -class GenerateConfig: - # fmt: off - - ################################################################################################################# - # Model-specific parameters - ################################################################################################################# - model_family: str = "openvla" # Model family - #the task-specific model after sf fine-tuning - pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model" # Task-specific checkpoint path - #the task-specific model after oft fine-tuning - original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model" # Reference checkpoint path - #feature vector is the difference between the two models, which represents the spatial features - vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth" - #the pt vla model initialized with the feature vector, named rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero} - initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla" - #the original pretrained openvla model - pt_ckpt: Union[str, Path] = "checkpoints/openvla_base" - #the weight of the feature vector when initializing the pt vla model - feature_vector_weight: float = 1 # Weight of feature vector for interpolation - - use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective - use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling 
objective (DDIM) - num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training - num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference - use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features - num_images_in_input: int = 2 # Number of images in the VLA input (default: 1) - use_proprio: bool = True # Whether to include proprio state in input - - center_crop: bool = True # Center crop? (if trained w/ random crop image aug) - num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy - - lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!) - - unnorm_key: Union[str, Path] = "" # Action un-normalization key - - load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization - load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization - - ################################################################################################################# - # LIBERO environment-specific parameters - ################################################################################################################# - task_suite_name: str = "de" # Task suite - num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim - num_trials_per_task: int = 50 # Number of rollouts per task - initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file - env_img_res: int = 256 # Resolution for environment images (not policy input resolution) - - ################################################################################################################# - # Utils - ################################################################################################################# - run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging - local_log_dir: str = "./experiments/logs" # Local directory for eval logs - - use_wandb: bool = False # Whether to also log results in Weights & Biases - wandb_entity: str = "your-wandb-entity" # Name of WandB entity - wandb_project: str = "your-wandb-project" # Name of WandB project - - seed: int = 7 # Random Seed (for reproducibility) - -def validate_config(cfg: GenerateConfig) -> None: - """Validate configuration parameters.""" - assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!" - - if "image_aug" in str(cfg.pretrained_checkpoint): - assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!" - - assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!" 
- - # Validate task suite - assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}" - -def initialize_model(cfg: GenerateConfig, only_pt: bool = False): #load action_head and noisy_action_projector separately - """Initialize model and associated components.""" - # Load model - model = get_model(cfg) - - # Load proprio projector if needed - proprio_projector = None - if cfg.use_proprio: - proprio_projector = get_proprio_projector( - cfg, - model.llm_dim, - proprio_dim=8, # 8-dimensional proprio for LIBERO - ) - - # Load action head if needed - action_head = None - if cfg.use_l1_regression or cfg.use_diffusion: - action_head = get_action_head(cfg, model.llm_dim) - - # Load noisy action projector if using diffusion - noisy_action_projector = None - if cfg.use_diffusion: - noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim) - - # Get OpenVLA processor if needed - processor = None - if not only_pt: - if cfg.model_family == "openvla": - processor = get_processor(cfg) - check_unnorm_key(cfg, model) - - return model, action_head, proprio_projector, noisy_action_projector, processor - -# @draccus.wrap() -def generate_feature_vector(cfg: GenerateConfig): - """Generate a feature vector (parameter differences) between two task-specific models.""" - # Validate configuration - - # Set random seed - set_seed_everywhere(cfg.seed) - - # Initialize model and components - model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg) - - original_config = GenerateConfig( - pretrained_checkpoint=cfg.original_pretrained_checkpoint, - task_suite_name=cfg.task_suite_name, - ) - - original_model, original_action_head, original_proprio_projector, original_noisy_action_projector, original_processor = initialize_model(original_config) - #for action_head and noisy_action_projector, these modules are not interpolated - assert len(model.state_dict()) == len(original_model.state_dict()) - feature_vector_dict = {} - total = len(original_model.state_dict()) - for name, original_model_param in tqdm(original_model.named_parameters(), total=total): - model_param = model.state_dict()[name] - feature_vector_dict[name] = (model_param - original_model_param).detach().cpu() - - return feature_vector_dict - -# @draccus.wrap() -def interpolate_feature_vector(cfg: GenerateConfig): - """Interpolate feature vector.""" - feature_vector_dict = torch.load(cfg.vector_save_path) - - pt_vla_config = GenerateConfig( - pretrained_checkpoint=cfg.pt_ckpt, - original_pretrained_checkpoint=cfg.original_pretrained_checkpoint, - vector_save_path=cfg.vector_save_path, - initialized_pt_vla_path=cfg.initialized_pt_vla_path, - feature_vector_weight=cfg.feature_vector_weight, - pt_ckpt=cfg.pt_ckpt, - task_suite_name=cfg.task_suite_name, - use_proprio=False, - use_l1_regression=False, - use_diffusion=False - ) - - pt_vla,_,_,_,_ = initialize_model(pt_vla_config, only_pt=True) - - #copy the SF parameters for checking the change before and after interpolation - model_sd = pt_vla.state_dict() - before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point} - - with torch.no_grad(): - pt_params = dict(pt_vla.named_parameters()) - for name, diff in feature_vector_dict.items(): - if name in pt_params: - pt_param = pt_params[name] - diff = diff.to(pt_param.device) - pt_param.add_(diff, alpha=cfg.feature_vector_weight) - - #check after interpolation - diffs_after = [] - for name, before_tensor in before_interp_sd.items(): - 
after_tensor = model_sd[name] - difference = (after_tensor - before_tensor).float().norm().item() - diffs_after.append(difference) - - print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, " - f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}") - - ######################################################### - return pt_vla - -@draccus.wrap() -def main(cfg: GenerateConfig): - if not os.path.exists(cfg.vector_save_path): - feature_vector_dict = generate_feature_vector(cfg) - torch.save(feature_vector_dict, cfg.vector_save_path) - else: - print(f"Feature vector already exists at {cfg.vector_save_path}") - initialized_pt_vla = interpolate_feature_vector(cfg) - os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True) - initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path) - -if __name__ == "__main__": - main() diff --git a/capvector-oft/capvector/interpolate.sh b/capvector-oft/capvector/interpolate.sh deleted file mode 100644 index 4debebe4a4e03c8652fabed63e6c508de943e064..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/interpolate.sh +++ /dev/null @@ -1,26 +0,0 @@ -TASK=spatial # or object / goal / 10 / 90 -VERSION=21.4 -PT_CKPT="checkpoints/openvla_base" -TASK_MODEL_CHECKPOINT="checkpoints/task_models/SF_${TASK}" -REFERENCE_MODEL_CHECKPOINT="checkpoints/reference_models/openvla_oft_libero_${TASK}" -VECTOR_SAVE_PATH="checkpoints/feature_vectors/feature_vector_with_SF_${TASK}_v${VERSION}.pth" -INITIALIZED_PT_VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_${TASK}_v${VERSION}" -TASK_SUITE_NAME="libero_${TASK}" - -python interpolate.py \ - --pretrained_checkpoint "$TASK_MODEL_CHECKPOINT" \ - --original_pretrained_checkpoint "$REFERENCE_MODEL_CHECKPOINT" \ - --vector_save_path "$VECTOR_SAVE_PATH" \ - --initialized_pt_vla_path $INITIALIZED_PT_VLA_PATH \ - --pt_ckpt $PT_CKPT\ - --feature_vector_weight 0.5 \ - --task_suite_name $TASK_SUITE_NAME - -#the code below is used to transplant the model parameters except for the vla backbone, such as processor and tokenizer, so that the initialized model is complete - -rsync -av \ - --ignore-existing \ - --exclude='*.safetensors' \ - --exclude='*.back.*' \ - $PT_CKPT/ \ - $INITIALIZED_PT_VLA_PATH/ diff --git a/capvector-oft/capvector/interpolate_robotwin.py b/capvector-oft/capvector/interpolate_robotwin.py deleted file mode 100644 index 6a19dcfdece46dd308dd33aa024569210a7607f2..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/interpolate_robotwin.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -This is for extracting feature vector from the openvla-oft model and interpolating it with the original openvla model. 
-""" - - -import os -import json -import logging - -import sys -from collections import deque -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Optional, Union -from PIL import Image - -import draccus -import numpy as np -from tqdm import tqdm -import torch -import copy - -import wandb - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.append(str(REPO_ROOT)) -from experiments.robot.openvla_utils import ( - get_action_head, - get_noisy_action_projector, - get_processor, - get_proprio_projector, - resize_image_for_policy, -) -from experiments.robot.robot_utils import ( - DATE_TIME, - get_action, - get_image_resize_size, - get_model, - invert_gripper_action, - normalize_gripper_action, - set_seed_everywhere, -) -from experiments.robot.libero.run_libero_eval import check_unnorm_key -from prismatic.vla.constants import NUM_ACTIONS_CHUNK -from prismatic.vla.constants import PROPRIO_DIM - -# Set up logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - - -@dataclass -class GenerateConfig: - # fmt: off - - ################################################################################################################# - # Model-specific parameters - ################################################################################################################# - model_family: str = "openvla" # Model family - #the task-specific model after sf fine-tuning - pretrained_checkpoint: Union[str, Path] = "checkpoints/task_model" # Task-specific checkpoint path - #the task-specific model after oft fine-tuning - original_pretrained_checkpoint: Union[str, Path] = "checkpoints/reference_model" # Reference checkpoint path - #feature vector is the difference between the two models, which represents the spatial features - vector_save_path: Union[str, Path] = "checkpoints/feature_vectors/feature_vector.pth" - #the pt vla model initialized with the feature vector, named rule: initailized_{pt_ckpt}_with_{task-specific model name}_${task name on libero} - initialized_pt_vla_path: Union[str, Path] = "checkpoints/initialized_pt_vla" - #the original pretrained openvla model - pt_ckpt: Union[str, Path] = "checkpoints/openvla_base" - #the weight of the feature vector when initializing the pt vla model - feature_vector_weight: float = 1 # Weight of feature vector for interpolation - - use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective - use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM) - num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training - num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference - use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features - num_images_in_input: int = 3 # Number of images in the VLA input (default: 1) - use_proprio: bool = True # Whether to include proprio state in input - - center_crop: bool = True # Center crop? (if trained w/ random crop image aug) - num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy - - lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!) 
- - unnorm_key: Union[str, Path] = "" # Action un-normalization key - - load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization - load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization - - ################################################################################################################# - # LIBERO environment-specific parameters - ################################################################################################################# - task_suite_name: str = "de" # Task suite - num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim - num_trials_per_task: int = 50 # Number of rollouts per task - initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file - env_img_res: int = 256 # Resolution for environment images (not policy input resolution) - - ################################################################################################################# - # Utils - ################################################################################################################# - run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging - local_log_dir: str = "./experiments/logs" # Local directory for eval logs - - use_wandb: bool = False # Whether to also log results in Weights & Biases - wandb_entity: str = "your-wandb-entity" # Name of WandB entity - wandb_project: str = "your-wandb-project" # Name of WandB project - - seed: int = 7 # Random Seed (for reproducibility) - -def validate_config(cfg: GenerateConfig) -> None: - """Validate configuration parameters.""" - assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!" - - if "image_aug" in str(cfg.pretrained_checkpoint): - assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!" - - assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!" 
- - # Validate task suite - # assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}" - -def initialize_model(cfg: GenerateConfig, only_pt: bool = False): #load action_head and noisy_action_projector separately - """Initialize model and associated components.""" - # Load model - model = get_model(cfg) - - # Load proprio projector if needed - proprio_projector = None - if cfg.use_proprio: - proprio_projector = get_proprio_projector( - cfg, - model.llm_dim, - proprio_dim=PROPRIO_DIM, #set the proprio_dim for different robots - ) - - # Load action head if needed - action_head = None - if cfg.use_l1_regression or cfg.use_diffusion: - action_head = get_action_head(cfg, model.llm_dim) - - # Load noisy action projector if using diffusion - noisy_action_projector = None - if cfg.use_diffusion: - noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim) - - # Get OpenVLA processor if needed - processor = None - if not only_pt: - if cfg.model_family == "openvla": - processor = get_processor(cfg) - # check_unnorm_key(cfg, model) - - return model, action_head, proprio_projector, noisy_action_projector, processor - -# @draccus.wrap() -def generate_feature_vector(cfg: GenerateConfig): - """Generate a feature vector (parameter differences) between two task-specific models.""" - # Validate configuration - - # Set random seed - set_seed_everywhere(cfg.seed) - - # Initialize model and components - model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg) - - original_config = GenerateConfig( - pretrained_checkpoint=cfg.original_pretrained_checkpoint, - task_suite_name=cfg.task_suite_name, - ) - - original_model, original_action_head, original_proprio_projector, original_noisy_action_projector, original_processor = initialize_model(original_config) - #for action_head and noisy_action_projector, these modules are not interpolated - assert len(model.state_dict()) == len(original_model.state_dict()) - feature_vector_dict = {} - total = len(original_model.state_dict()) - for name, original_model_param in tqdm(original_model.named_parameters(), total=total): - model_param = model.state_dict()[name] - feature_vector_dict[name] = (model_param - original_model_param).detach().cpu() - - return feature_vector_dict - -# @draccus.wrap() -def interpolate_feature_vector(cfg: GenerateConfig): - """Interpolate feature vector.""" - feature_vector_dict = torch.load(cfg.vector_save_path) - - pt_vla_config = GenerateConfig( - pretrained_checkpoint=cfg.pt_ckpt, - original_pretrained_checkpoint=cfg.original_pretrained_checkpoint, - vector_save_path=cfg.vector_save_path, - initialized_pt_vla_path=cfg.initialized_pt_vla_path, - feature_vector_weight=cfg.feature_vector_weight, - pt_ckpt=cfg.pt_ckpt, - task_suite_name=cfg.task_suite_name, - use_proprio=False, - use_l1_regression=False, - use_diffusion=False - ) - - pt_vla,_,_,_,_ = initialize_model(pt_vla_config, only_pt=True) - - #copy the SF parameters for checking the change before and after interpolation - model_sd = pt_vla.state_dict() - before_interp_sd = {k: v.clone() for k, v in model_sd.items() if v.dtype.is_floating_point} - - with torch.no_grad(): - pt_params = dict(pt_vla.named_parameters()) - for name, diff in feature_vector_dict.items(): - if name in pt_params: - pt_param = pt_params[name] - diff = diff.to(pt_param.device) - pt_param.add_(diff, alpha=cfg.feature_vector_weight) - - #check after interpolation - diffs_after = [] - for name, before_tensor in 
before_interp_sd.items(): - after_tensor = model_sd[name] - difference = (after_tensor - before_tensor).float().norm().item() - diffs_after.append(difference) - - print(f"[DEBUG] post-interp (SF -> interp): mean={sum(diffs_after)/len(diffs_after):.6f}, " - f"max={max(diffs_after):.6f}, num_tensors={len(diffs_after)}") - - ######################################################### - return pt_vla - -@draccus.wrap() -def main(cfg: GenerateConfig): - if not os.path.exists(cfg.vector_save_path): - feature_vector_dict = generate_feature_vector(cfg) - torch.save(feature_vector_dict, cfg.vector_save_path) - else: - print(f"Feature vector already exists at {cfg.vector_save_path}") - initialized_pt_vla = interpolate_feature_vector(cfg) - os.makedirs(cfg.initialized_pt_vla_path, exist_ok=True) - initialized_pt_vla.save_pretrained(cfg.initialized_pt_vla_path) - -if __name__ == "__main__": - main() diff --git a/capvector-oft/capvector/tools/check_model_config.py b/capvector-oft/capvector/tools/check_model_config.py deleted file mode 100644 index ddc29e880cc4fd35a907a04ca9d8bf178955861b..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/tools/check_model_config.py +++ /dev/null @@ -1,23 +0,0 @@ -#This is for checking the completeness of the model parameters. -import argparse - -import torch - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)") - args = parser.parse_args() - - fv = torch.load(args.checkpoint_path, map_location="cpu") - - print("num_tensors:", len(fv)) - nz = 0 - for _, value in fv.items(): - if value.abs().sum().item() != 0: - nz += 1 - print("nonzero_tensors:", nz) - - -if __name__ == "__main__": - main() diff --git a/capvector-oft/capvector/tools/compute_lora_diff.py b/capvector-oft/capvector/tools/compute_lora_diff.py deleted file mode 100644 index 5232107c4efd638287fdd7e10b5e0dcf5ec247cd..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/tools/compute_lora_diff.py +++ /dev/null @@ -1,36 +0,0 @@ -#This is for computing the difference between the base model and the target model. 
-from safetensors.torch import load_file, save_file -import torch -import argparse - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--base", required=True) - parser.add_argument("--target", required=True) - parser.add_argument("--out", default="lora_diff.safetensors") - args = parser.parse_args() - - base = load_file(args.base) - target = load_file(args.target) - - diff = {} - - print("=== Key Comparison ===") - only_in_base = set(base) - set(target) - only_in_target = set(target) - set(base) - - print("Only in base:", list(only_in_base)[:10]) - print("Only in target:", list(only_in_target)[:10]) - - for k in target: - if k in base: - diff[k] = target[k] - base[k] - else: - # keep the new parameters - diff[k] = target[k].clone() - - save_file(diff, args.out) - print(f"\nSaved diff to: {args.out}") - -if __name__ == "__main__": - main() diff --git a/capvector-oft/capvector/tools/compute_lora_diff.sh b/capvector-oft/capvector/tools/compute_lora_diff.sh deleted file mode 100644 index 8cbd7ae132a7f853a566a2929761756a3a33596a..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/tools/compute_lora_diff.sh +++ /dev/null @@ -1,8 +0,0 @@ -BASE_ADAPTER="checkpoints/reference_models/openvla_oft_libero_spatial/lora_adapter/adapter_model.safetensors" -TARGET_ADAPTER="checkpoints/task_models/SF_spatial/lora_adapter/adapter_model.safetensors" -OUTPUT_DIFF="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors" - -python compute_lora_diff.py \ - --base "$BASE_ADAPTER" \ - --target "$TARGET_ADAPTER" \ - --out "$OUTPUT_DIFF" diff --git a/capvector-oft/capvector/tools/vector_analyze.py b/capvector-oft/capvector/tools/vector_analyze.py deleted file mode 100644 index b7aca71369fe0474811c7fad2992eff5b4599589..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/tools/vector_analyze.py +++ /dev/null @@ -1,153 +0,0 @@ -#This is for analyzing the vector of the model and finding out which layers have the largest absolute values. -import argparse -import csv -import os -import re -from collections import OrderedDict, defaultdict - -import matplotlib.pyplot as plt -import torch - - -LAYER_PREFIX = "language_model.model.layers." 
-NUM_LAYERS = 32 -USE_LOG_Y = True - - -def pick_state_dict(obj): - if isinstance(obj, (OrderedDict, dict)): - for key in ["state_dict", "model_state_dict", "model", "net", "weights", "params"]: - if key in obj and isinstance(obj[key], (OrderedDict, dict)): - return obj[key] - if any(torch.is_tensor(value) for value in obj.values()): - return obj - return None - - -def aggregate_layers_abs_sum(state_dict): - layer_sum = defaultdict(float) - layer_cnt = defaultdict(int) - pattern = re.compile(r"^" + re.escape(LAYER_PREFIX) + r"(\d+)\.") - - for name, tensor in state_dict.items(): - if not isinstance(name, str): - continue - match = pattern.match(name) - if match is None or not torch.is_tensor(tensor): - continue - - layer_id = int(match.group(1)) - if layer_id < 0 or layer_id >= NUM_LAYERS: - continue - - value = tensor.detach() - if value.is_cuda: - value = value.cpu() - - value = value.to(torch.float64) - layer_sum[layer_id] += value.abs().sum().item() - layer_cnt[layer_id] += 1 - - for layer_id in range(NUM_LAYERS): - layer_sum[layer_id] = float(layer_sum.get(layer_id, 0.0)) - layer_cnt[layer_id] = int(layer_cnt.get(layer_id, 0)) - - return layer_sum, layer_cnt - - -def save_layer_csv(layer_sum, layer_cnt, path): - output_dir = os.path.dirname(path) - if output_dir: - os.makedirs(output_dir, exist_ok=True) - with open(path, "w", newline="") as file_obj: - writer = csv.DictWriter(file_obj, fieldnames=["layer_id", "abs_sum", "num_tensors"]) - writer.writeheader() - for layer_id in range(NUM_LAYERS): - writer.writerow( - { - "layer_id": layer_id, - "abs_sum": f"{layer_sum[layer_id]:.12e}", - "num_tensors": layer_cnt[layer_id], - } - ) - - -def plot_line(xs, ys, out_png, title): - ys_plot = ys[:] - if USE_LOG_Y: - min_pos = min([value for value in ys_plot if value > 0], default=1e-300) - eps = min_pos * 1e-6 if min_pos > 0 else 1e-300 - ys_plot = [value if value > 0 else eps for value in ys_plot] - - plt.figure(figsize=(12, 4.5)) - plt.plot(xs, ys_plot, marker="o", linewidth=1.5) - plt.xlabel("Layer id") - plt.ylabel("abs_sum (all params in layer)") - plt.title(title) - plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5) - if USE_LOG_Y: - plt.yscale("log") - plt.tight_layout() - plt.savefig(out_png, dpi=200) - plt.close() - - -def plot_bar(xs, ys, out_png, title): - ys_plot = ys[:] - if USE_LOG_Y: - min_pos = min([value for value in ys_plot if value > 0], default=1e-300) - eps = min_pos * 1e-6 if min_pos > 0 else 1e-300 - ys_plot = [value if value > 0 else eps for value in ys_plot] - - plt.figure(figsize=(12, 4.5)) - plt.bar(xs, ys_plot) - plt.xlabel("Layer id") - plt.ylabel("abs_sum (all params in layer)") - plt.title(title) - plt.grid(True, which="both", axis="y", linestyle="--", linewidth=0.5, alpha=0.5) - if USE_LOG_Y: - plt.yscale("log") - plt.tight_layout() - plt.savefig(out_png, dpi=200) - plt.close() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)") - args = parser.parse_args() - - base = os.path.splitext(args.checkpoint_path)[0] - out_csv = base + "_language_model_layers_abs_sum.csv" - out_png_line = base + "_language_model_layers_abs_sum_line.png" - out_png_bar = base + "_language_model_layers_abs_sum_bar.png" - - ckpt = torch.load(args.checkpoint_path, map_location="cpu") - state_dict = pick_state_dict(ckpt) - - if state_dict is None: - print("Not a state_dict-like dict. 
Type:", type(ckpt)) - if isinstance(ckpt, dict): - print("Top-level keys:", list(ckpt.keys())[:50]) - raise SystemExit(1) - - layer_sum, layer_cnt = aggregate_layers_abs_sum(state_dict) - save_layer_csv(layer_sum, layer_cnt, out_csv) - print(f"Saved CSV: {out_csv}") - - xs = list(range(NUM_LAYERS)) - ys = [layer_sum[i] for i in xs] - plot_line(xs, ys, out_png_line, f"{LAYER_PREFIX}*: abs_sum per layer") - plot_bar(xs, ys, out_png_bar, f"{LAYER_PREFIX}*: abs_sum per layer") - - print(f"Saved plot: {out_png_line}") - print(f"Saved plot: {out_png_bar}") - - top = sorted(((i, layer_sum[i], layer_cnt[i]) for i in xs), key=lambda item: item[1], reverse=True)[:5] - print("Top-5 layers by abs_sum:") - for layer_id, abs_sum, tensor_count in top: - print(f" layer {layer_id:02d}: abs_sum={abs_sum:.6e}, tensors={tensor_count}") - - -if __name__ == "__main__": - main() diff --git a/capvector-oft/capvector/tools/vector_regularize.py b/capvector-oft/capvector/tools/vector_regularize.py deleted file mode 100644 index 279646e7d9cbeba522bf6636d7bea0d27c8471f1..0000000000000000000000000000000000000000 --- a/capvector-oft/capvector/tools/vector_regularize.py +++ /dev/null @@ -1,75 +0,0 @@ -# Used to regularize feature vectors by first computing the absolute-sum of each layer and then performing normalization - -import argparse -from collections import OrderedDict - -import torch - - -def pick_state_dict(obj): - """Extract state_dict from a checkpoint-like object""" - if isinstance(obj, (OrderedDict, dict)): - for k in ["state_dict", "model_state_dict", "model", "net", "weights", "params"]: - if k in obj and isinstance(obj[k], (OrderedDict, dict)): - return obj[k] - if any(torch.is_tensor(v) for v in obj.values()): - return obj - return None - - -def calculate_total_abs_sum(state_dict): - """Compute the sum of absolute values over all parameters""" - total_sum = 0.0 - param_count = 0 - - for name, tensor in state_dict.items(): - if not torch.is_tensor(tensor): - continue - - x = tensor.detach() - if x.is_cuda: - x = x.cpu() - - # Use float64 to ensure numerical precision - x = x.to(torch.float64) - abs_sum = x.abs().sum().item() - total_sum += abs_sum - param_count += 1 - - print(f"{name}: {abs_sum:.12e} (shape: {list(x.shape)}, numel: {x.numel()})") - - return total_sum, param_count - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("checkpoint_path", help="Path to the feature vector checkpoint (.pth)") - args = parser.parse_args() - - print(f"Loading checkpoint: {args.checkpoint_path}") - ckpt = torch.load(args.checkpoint_path, map_location="cpu") - sd = pick_state_dict(ckpt) - - if sd is None: - print("Error: failed to extract state_dict from checkpoint") - print(f"Checkpoint type: {type(ckpt)}") - if isinstance(ckpt, dict): - print(f"Top-level keys: {list(ckpt.keys())[:20]}") - raise SystemExit(1) - - print(f"\nFound {len(sd)} parameters\n") - print("=" * 80) - print("Absolute-sum of each parameter:") - print("=" * 80) - - total_abs_sum, param_count = calculate_total_abs_sum(sd) - - print("=" * 80) - print(f"\nSummary:") - print(f" Total number of parameters: {param_count}") - print(f" Sum of absolute values of all parameters: {total_abs_sum:.12e}") - print(f" Sum of absolute values of all parameters (scientific notation): {total_abs_sum:.6e}") - - -if __name__ == "__main__": - main() diff --git a/capvector-oft/experiments/robot/aloha/aloha_utils.py b/capvector-oft/experiments/robot/aloha/aloha_utils.py deleted file mode 100644 index 
bb8d2d747ffddeced9064cfe7bf99c20bf1865b3..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/aloha_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Utils for evaluating policies in real-world ALOHA environments.""" - -import os - -import imageio -import numpy as np -from PIL import Image - -from experiments.robot.aloha.real_env import make_real_env -from experiments.robot.robot_utils import ( - DATE, - DATE_TIME, -) - - -def get_next_task_label(task_label): - """Prompt the user to input the next task.""" - if task_label == "": - user_input = "" - while user_input == "": - user_input = input("Enter the task name: ") - task_label = user_input - else: - user_input = input("Enter the task name (or leave blank to repeat the previous task): ") - if user_input == "": - pass # Do nothing -> Let task_label be the same - else: - task_label = user_input - print(f"Task: {task_label}") - return task_label - - -def get_aloha_env(): - """Initializes and returns the ALOHA environment.""" - env = make_real_env(init_node=True) - return env - - -def resize_image_for_preprocessing(img): - """ - Takes numpy array corresponding to a single image and resizes to 256x256, exactly as done - in the ALOHA data preprocessing script, which is used before converting the dataset to RLDS. - """ - ALOHA_PREPROCESS_SIZE = 256 - img = np.array( - Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC) - ) # BICUBIC is default; specify explicitly to make it clear - return img - - -def get_aloha_image(obs): - """Extracts third-person image from observations and preprocesses it.""" - # obs: dm_env._environment.TimeStep - img = obs.observation["images"]["cam_high"] - img = resize_image_for_preprocessing(img) - return img - - -def get_aloha_wrist_images(obs): - """Extracts both wrist camera images from observations and preprocesses them.""" - # obs: dm_env._environment.TimeStep - left_wrist_img = obs.observation["images"]["cam_left_wrist"] - right_wrist_img = obs.observation["images"]["cam_right_wrist"] - left_wrist_img = resize_image_for_preprocessing(left_wrist_img) - right_wrist_img = resize_image_for_preprocessing(right_wrist_img) - return left_wrist_img, right_wrist_img - - -def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None): - """Saves an MP4 replay of an episode.""" - rollout_dir = f"./rollouts/{DATE}" - os.makedirs(rollout_dir, exist_ok=True) - processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] - filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}" - if notes is not None: - filetag += f"--{notes}" - mp4_path = f"{filetag}.mp4" - video_writer = imageio.get_writer(mp4_path, fps=25) - for img in rollout_images: - video_writer.append_data(img) - video_writer.close() - print(f"Saved rollout MP4 at path {mp4_path}") - if log_file is not None: - log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") - return mp4_path diff --git a/capvector-oft/experiments/robot/aloha/constants.py b/capvector-oft/experiments/robot/aloha/constants.py deleted file mode 100644 index 312b180978ee56da2d69442702fbb94f20f19f58..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/constants.py +++ /dev/null @@ -1,100 +0,0 @@ -### Task parameters - -DATA_DIR = '/scr2/moojink/data/aloha1/' -TASK_CONFIGS = { - # fold shorts - 'fold_shorts':{ - 'dataset_dir': DATA_DIR + 
'/fold_shorts', - 'num_episodes': 20, - 'episode_len': 1000, - 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] - }, - # fold shirt - 'fold_shirt':{ - 'dataset_dir': DATA_DIR + '/fold_shirt', - 'num_episodes': 30, - 'episode_len': 1250, - 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] - }, - # scoop X into bowl - 'scoop_raisins_into_bowl':{ - 'dataset_dir': DATA_DIR + '/scoop_raisins_into_bowl', - 'num_episodes': 15, - 'episode_len': 900, - 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] - }, - 'scoop_almonds_and_green_M&Ms_into_bowl':{ - 'dataset_dir': DATA_DIR + '/scoop_almonds_and_green_M&Ms_into_bowl', - 'num_episodes': 15, - 'episode_len': 900, - 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] - }, - 'scoop_pretzels_into_bowl':{ - 'dataset_dir': DATA_DIR + '/scoop_pretzels_into_bowl', - 'num_episodes': 15, - 'episode_len': 900, - 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] - }, - # put X into pot - 'put_red_pepper_into_pot':{ - 'dataset_dir': DATA_DIR + '/put_red_pepper_into_pot', - 'num_episodes': 100, - 'episode_len': 400, - 'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] - }, - 'put_yellow_corn_into_pot':{ - 'dataset_dir': DATA_DIR + '/put_yellow_corn_into_pot', - 'num_episodes': 100, - 'episode_len': 400, - 'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] - }, - 'put_green_pepper_into_pot':{ - 'dataset_dir': DATA_DIR + '/put_green_pepper_into_pot', - 'num_episodes': 100, - 'episode_len': 400, - 'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] - }, -} - -### ALOHA fixed constants -DT = 0.04 # 1 / 0.04 -> 25 Hz -JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] -START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] - -# Left finger position limits (qpos[7]), right_finger = -1 * left_finger -MASTER_GRIPPER_POSITION_OPEN = 0.02417 -MASTER_GRIPPER_POSITION_CLOSE = 0.01244 -PUPPET_GRIPPER_POSITION_OPEN = 0.05800 -PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 - -# Gripper joint limits (qpos[6]) -MASTER_GRIPPER_JOINT_OPEN = 0.3083 # For ALOHA 1 -MASTER_GRIPPER_JOINT_CLOSE = -0.6842 # For ALOHA 1 -# MASTER_GRIPPER_JOINT_OPEN = -0.8 # For ALOHA 2 -# MASTER_GRIPPER_JOINT_CLOSE = -1.65 # For ALOHA 2 -PUPPET_GRIPPER_JOINT_OPEN = 1.4910 -PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 - -############################ Helper functions ############################ - -MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) -PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) -MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE -PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE -MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) - -MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) -PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - 
PUPPET_GRIPPER_JOINT_CLOSE) -MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE -PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE -MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) - -MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) -PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) - -MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE -MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)) -PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE -PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)) - -MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2 diff --git a/capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py b/capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py deleted file mode 100644 index 5a835a5b346af805308d876b8b87c8570a988fa9..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/preprocess_split_aloha_data.py +++ /dev/null @@ -1,260 +0,0 @@ -""" -Preprocesses ALOHA dataset(s) and splits them into train/val sets. - -Preprocessing includes downsizing images from 480x640 to 256x256. -Splits happen at the episode level (not step level), which means that -an episode is treated as an atomic unit that entirely goes to either -the train set or val set. - -Original ALOHA data layout: - /PATH/TO/DATASET/dataset_name/ - - episode_0.hdf5 - - episode_1.hdf5 - - ... - - episode_N.hdf5 - -Preprocessed data layout (after running this script): - /PATH/TO/PREPROCESSED_DATASETS/dataset_name/ - - train/ - - episode_0.hdf5 - - episode_1.hdf5 - - ... - - episode_M.hdf5 - - val/ - - episode_0.hdf5 - - episode_1.hdf5 - - ... 
- - episode_K.hdf5 - - where N > M > K - -Example usage: - # "put X into pot" task - python experiments/robot/aloha/preprocess_split_aloha_data.py \ - --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \ - --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ - --percent_val 0.05 && \ - python experiments/robot/aloha/preprocess_split_aloha_data.py \ - --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \ - --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ - --percent_val 0.05 && \ - python experiments/robot/aloha/preprocess_split_aloha_data.py \ - --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \ - --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ - --percent_val 0.05 -""" - -import argparse -import glob -import os -import random - -import h5py -import numpy as np -from PIL import Image -from tqdm import tqdm - - -def load_hdf5(demo_path): - """Loads single episode.""" - if not os.path.isfile(demo_path): - print(f"Dataset does not exist at \n{demo_path}\n") - exit() - - print(f"Loading {demo_path}...") - with h5py.File(demo_path, "r") as root: - is_sim = root.attrs["sim"] - qpos = root["/observations/qpos"][()] - qvel = root["/observations/qvel"][()] - effort = root["/observations/effort"][()] - action = root["/action"][()] - image_dict = dict() - for cam_name in root["/observations/images/"].keys(): - image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()] - print(f"Loading episode complete: {demo_path}") - - return qpos, qvel, effort, action, image_dict, is_sim - - -def load_and_preprocess_all_episodes(demo_paths, out_dataset_dir): - """ - Loads and preprocesses all episodes. - Resizes all images in one episode before loading the next, to reduce memory usage. - """ - cam_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"] - idx = 0 - for demo in tqdm(demo_paths): - qpos, qvel, effort, action, image_dict, is_sim = load_hdf5(demo) - # Save non-image info - episode_len = image_dict["cam_high"].shape[0] - # Resize all images - print("Resizing images in episode...") - for k in cam_names: - resized_images = [] - for i in range(episode_len): - resized_images.append( - np.array( - Image.fromarray(image_dict[k][i]).resize( - (args.img_resize_size, args.img_resize_size), resample=Image.BICUBIC - ) - ) # BICUBIC is default; specify explicitly to make it clear - ) - image_dict[k] = np.stack(resized_images) - print("Resizing images in episode complete!") - # Save preprocessed episode - data_dict = dict( - qpos=qpos, - qvel=qvel, - effort=effort, - action=action, - image_dict=image_dict, - is_sim=is_sim, - ) - save_new_hdf5(out_dataset_dir, data_dict, idx) - idx += 1 - - -def randomly_split(full_qpos, full_qvel, full_effort, full_action, full_image_dict, percent_val): - """Randomly splits dataset into train and validation sets.""" - # Create a list of episode indices - num_episodes_total = len(full_qpos) - indices = list(range(num_episodes_total)) - # Shuffle the episode indices - random.shuffle(indices) - # Create new lists using the shuffled indices - shuffled_qpos = [full_qpos[idx] for idx in indices] - shuffled_qvel = [full_qvel[idx] for idx in indices] - shuffled_effort = [full_effort[idx] for idx in indices] - shuffled_action = [full_action[idx] for idx in indices] - shuffled_image_dict = { - "cam_high": [], - "cam_left_wrist": [], - "cam_right_wrist": [], - } - for k in full_image_dict.keys(): - shuffled_image_dict[k] = [full_image_dict[k][idx] for idx in indices] - # Split into train and val sets - 
num_episodes_val = int(num_episodes_total * percent_val) - print(f"Total # steps: {num_episodes_total}; using {num_episodes_val} ({percent_val:.2f}%) for val set") - num_episodes_train = num_episodes_total - num_episodes_val - train_dict = dict( - qpos=shuffled_qpos[:num_episodes_train], - qvel=shuffled_qvel[:num_episodes_train], - effort=shuffled_effort[:num_episodes_train], - action=shuffled_action[:num_episodes_train], - image_dict=dict( - cam_high=shuffled_image_dict["cam_high"][:num_episodes_train], - cam_left_wrist=shuffled_image_dict["cam_left_wrist"][:num_episodes_train], - cam_right_wrist=shuffled_image_dict["cam_right_wrist"][:num_episodes_train], - ), - ) - val_dict = dict( - qpos=shuffled_qpos[num_episodes_train:], - qvel=shuffled_qvel[num_episodes_train:], - effort=shuffled_effort[num_episodes_train:], - action=shuffled_action[num_episodes_train:], - image_dict=dict( - cam_high=shuffled_image_dict["cam_high"][num_episodes_train:], - cam_left_wrist=shuffled_image_dict["cam_left_wrist"][num_episodes_train:], - cam_right_wrist=shuffled_image_dict["cam_right_wrist"][num_episodes_train:], - ), - ) - return train_dict, val_dict - - -def save_new_hdf5(out_dataset_dir, data_dict, episode_idx): - """Saves an HDF5 file for a new episode.""" - camera_names = data_dict["image_dict"].keys() - H, W, C = data_dict["image_dict"]["cam_high"][0].shape - out_path = os.path.join(out_dataset_dir, f"episode_{episode_idx}.hdf5") - # Save HDF5 with same structure as original demos (except that now we combine all episodes into one HDF5 file) - with h5py.File( - out_path, "w", rdcc_nbytes=1024**2 * 2 - ) as root: # Magic constant for rdcc_nbytes comes from ALOHA codebase - episode_len = data_dict["qpos"].shape[0] - root.attrs["sim"] = data_dict["is_sim"] - obs = root.create_group("observations") - _ = obs.create_dataset("qpos", (episode_len, 14)) - _ = obs.create_dataset("qvel", (episode_len, 14)) - _ = obs.create_dataset("effort", (episode_len, 14)) - root["/observations/qpos"][...] = data_dict["qpos"] - root["/observations/qvel"][...] = data_dict["qvel"] - root["/observations/effort"][...] = data_dict["effort"] - image = obs.create_group("images") - for cam_name in camera_names: - _ = image.create_dataset( - cam_name, - (episode_len, H, W, C), - dtype="uint8", - chunks=(1, H, W, C), - ) - root[f"/observations/images/{cam_name}"][...] = data_dict["image_dict"][cam_name] - _ = root.create_dataset("action", (episode_len, 14)) - root["/action"][...] = data_dict["action"] - # Compute and save *relative* actions as well - actions = data_dict["action"] - relative_actions = np.zeros_like(actions) - relative_actions[:-1] = actions[1:] - actions[:-1] # Relative actions are the changes in joint pos - relative_actions[-1] = relative_actions[-2] # Just copy the second-to-last action for the last action - _ = root.create_dataset("relative_action", (episode_len, 14)) - root["/relative_action"][...] 
= relative_actions - print(f"Saved dataset: {out_path}") - - -def main(args): - # Create directory to save preprocessed dataset (if it doesn't exist already) - os.makedirs(args.out_base_dir, exist_ok=True) - out_dataset_dir = os.path.join(args.out_base_dir, os.path.basename(args.dataset_path.rstrip("/"))) - os.makedirs(out_dataset_dir, exist_ok=True) - # Get list of filepaths of all episodes - all_demo_paths = glob.glob(os.path.join(args.dataset_path, "*.hdf5")) # List of HDF5 filepaths - all_demo_paths.sort() - # Create a list of episode indices - num_episodes_total = len(all_demo_paths) - indices = list(range(num_episodes_total)) - # Shuffle the episode indices - random.shuffle(indices) - # Split into train and val sets - num_episodes_val = int(num_episodes_total * args.percent_val) - print(f"Total # episodes: {num_episodes_total}; using {num_episodes_val} ({args.percent_val:.2f}%) for val set") - num_episodes_train = num_episodes_total - num_episodes_val - train_indices = indices[:num_episodes_train] - val_indices = indices[num_episodes_train:] - train_demo_paths = [all_demo_paths[i] for i in train_indices] - val_demo_paths = [all_demo_paths[i] for i in val_indices] - # Preprocess all episodes and save the result - out_dataset_dir_train = os.path.join(out_dataset_dir, "train") - out_dataset_dir_val = os.path.join(out_dataset_dir, "val") - os.makedirs(out_dataset_dir_train, exist_ok=True) - os.makedirs(out_dataset_dir_val, exist_ok=True) - load_and_preprocess_all_episodes(train_demo_paths, out_dataset_dir_train) - load_and_preprocess_all_episodes(val_demo_paths, out_dataset_dir_val) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dataset_path", - required=True, - help="Path to raw ALOHA dataset directory. Example: /PATH/TO/USER/data/aloha_raw/put_green_pepper_into_pot/", - ) - parser.add_argument( - "--out_base_dir", - required=True, - help="Path to directory in which to save preprocessed dataset. Example: /PATH/TO/USER/data/aloha_preprocessed/", - ) - parser.add_argument( - "--percent_val", - type=float, - help="Percent of dataset to use as validation set (measured in episodes, not steps).", - default=0.05, - ) - parser.add_argument( - "--img_resize_size", - type=int, - help="Size to resize images to. 
Final images will be square (img_resize_size x img_resize_size pixels).", - default=256, - ) - args = parser.parse_args() - - main(args) diff --git a/capvector-oft/experiments/robot/aloha/real_env.py b/capvector-oft/experiments/robot/aloha/real_env.py deleted file mode 100644 index a2b747ec537382a11a655e5e21e33c3a876a1c16..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/real_env.py +++ /dev/null @@ -1,213 +0,0 @@ -import time -import numpy as np -import collections -import matplotlib.pyplot as plt -import dm_env - -from experiments.robot.aloha.constants import DT, START_ARM_POSE, MASTER_GRIPPER_JOINT_NORMALIZE_FN, PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN -from experiments.robot.aloha.constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN -from experiments.robot.aloha.constants import PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE -from experiments.robot.aloha.robot_utils import Recorder, ImageRecorder -from experiments.robot.aloha.robot_utils import setup_master_bot, setup_puppet_bot, move_arms, move_grippers -from interbotix_xs_modules.arm import InterbotixManipulatorXS -from interbotix_xs_msgs.msg import JointSingleCommand - -import IPython -e = IPython.embed - -class RealEnv: - """ - Environment for real robot bi-manual manipulation - Action space: [left_arm_qpos (6), # absolute joint position - left_gripper_positions (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) - - Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position - left_gripper_position (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) - "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) - left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) - right_arm_qvel (6), # absolute joint velocity (rad) - right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) - "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8' - "cam_low": (480x640x3), # h, w, c, dtype='uint8' - "cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8' - "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8' - """ - - def __init__(self, init_node, setup_robots=True): - self.puppet_bot_left = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper", - robot_name=f'puppet_left', init_node=init_node) - self.puppet_bot_right = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper", - robot_name=f'puppet_right', init_node=False) - if setup_robots: - self.setup_robots() - - self.recorder_left = Recorder('left', init_node=False) - self.recorder_right = Recorder('right', init_node=False) - self.image_recorder = ImageRecorder(init_node=False) - self.gripper_command = JointSingleCommand(name="gripper") - - def setup_robots(self): - setup_puppet_bot(self.puppet_bot_left) - setup_puppet_bot(self.puppet_bot_right) - - def get_qpos(self): - left_qpos_raw = self.recorder_left.qpos - right_qpos_raw = self.recorder_right.qpos - left_arm_qpos = left_qpos_raw[:6] - right_arm_qpos = right_qpos_raw[:6] - left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])] # this is position not joint - right_gripper_qpos = 
[PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])] # this is position not joint - return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) - - def get_qvel(self): - left_qvel_raw = self.recorder_left.qvel - right_qvel_raw = self.recorder_right.qvel - left_arm_qvel = left_qvel_raw[:6] - right_arm_qvel = right_qvel_raw[:6] - left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])] - right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])] - return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) - - def get_effort(self): - left_effort_raw = self.recorder_left.effort - right_effort_raw = self.recorder_right.effort - left_robot_effort = left_effort_raw[:7] - right_robot_effort = right_effort_raw[:7] - return np.concatenate([left_robot_effort, right_robot_effort]) - - def get_images(self): - return self.image_recorder.get_images() - - def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized): - left_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized) - self.gripper_command.cmd = left_gripper_desired_joint - self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command) - - right_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(right_gripper_desired_pos_normalized) - self.gripper_command.cmd = right_gripper_desired_joint - self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command) - - def _reset_joints(self): - reset_position = START_ARM_POSE[:6] - move_arms([self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1) - - def _reset_gripper(self): - """Set to position mode and do position resets: first open then close. 
Then change back to PWM mode""" - move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) - move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1) - - def _get_obs(self): - obs = collections.OrderedDict() - obs['qpos'] = self.get_qpos() - obs['qvel'] = self.get_qvel() - obs['effort'] = self.get_effort() - obs['images'] = self.get_images() - return obs - - def get_observation(self, t=0): - step_type = dm_env.StepType.FIRST if t == 0 else dm_env.StepType.MID - return dm_env.TimeStep( - step_type=step_type, - reward=self.get_reward(), - discount=None, - observation=self._get_obs() - ) - - def get_reward(self): - return 0 - - def reset(self, fake=False): - if not fake: - # Reboot puppet robot gripper motors - self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True) - self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True) - self._reset_joints() - self._reset_gripper() - return dm_env.TimeStep( - step_type=dm_env.StepType.FIRST, - reward=self.get_reward(), - discount=None, - observation=self._get_obs()) - - def step(self, action): - state_len = int(len(action) / 2) - left_action = action[:state_len] - right_action = action[state_len:] - self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False) - self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False) - self.set_gripper_pose(left_action[-1], right_action[-1]) - time.sleep(DT) - return dm_env.TimeStep( - step_type=dm_env.StepType.MID, - reward=self.get_reward(), - discount=None, - observation=self._get_obs()) - - -def get_action(master_bot_left, master_bot_right): - action = np.zeros(14) # 6 joint + 1 gripper, for two arms - # Arm actions - action[:6] = master_bot_left.dxl.joint_states.position[:6] - action[7:7+6] = master_bot_right.dxl.joint_states.position[:6] - # Gripper actions - action[6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6]) - action[7+6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6]) - - return action - - -def make_real_env(init_node, setup_robots=True): - env = RealEnv(init_node, setup_robots) - return env - - -def test_real_teleop(): - """ - Test bimanual teleoperation and show image observations onscreen. - It first reads joint poses from both master arms. - Then use it as actions to step the environment. - The environment returns full observations including images. - - An alternative approach is to have separate scripts for teleoperation and observation recording. 
- This script will result in higher fidelity (obs, action) pairs - """ - - onscreen_render = True - render_cam = 'cam_left_wrist' - - # source of data - master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", - robot_name=f'master_left', init_node=True) - master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", - robot_name=f'master_right', init_node=False) - setup_master_bot(master_bot_left) - setup_master_bot(master_bot_right) - - # setup the environment - env = make_real_env(init_node=False) - ts = env.reset(fake=True) - episode = [ts] - # setup visualization - if onscreen_render: - ax = plt.subplot() - plt_img = ax.imshow(ts.observation['images'][render_cam]) - plt.ion() - - for t in range(1000): - action = get_action(master_bot_left, master_bot_right) - ts = env.step(action) - episode.append(ts) - - if onscreen_render: - plt_img.set_data(ts.observation['images'][render_cam]) - plt.pause(DT) - else: - time.sleep(DT) - - -if __name__ == '__main__': - test_real_teleop() diff --git a/capvector-oft/experiments/robot/aloha/requirements_aloha.txt b/capvector-oft/experiments/robot/aloha/requirements_aloha.txt deleted file mode 100644 index e26a4eba548e81ca51a06d2ae07f4737b8aa28ed..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/requirements_aloha.txt +++ /dev/null @@ -1,26 +0,0 @@ -numpy<2 -draccus -torchvision -torch -pyquaternion -pyyaml -rospkg -pexpect -mujoco==2.3.7 -dm_control==1.0.14 -opencv-python -matplotlib -einops -packaging -h5py -traitlets -ipdb -IPython -modern_robotics -Pillow -termcolor -imageio[ffmpeg] -uvicorn -fastapi -requests -json_numpy diff --git a/capvector-oft/experiments/robot/aloha/robot_utils.py b/capvector-oft/experiments/robot/aloha/robot_utils.py deleted file mode 100644 index 8d91655ed6b364343daaa081ca56a4323fb0e00c..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/robot_utils.py +++ /dev/null @@ -1,187 +0,0 @@ -import numpy as np -import time -from experiments.robot.aloha.constants import DT -from interbotix_xs_msgs.msg import JointSingleCommand - -import IPython -e = IPython.embed - -class ImageRecorder: - def __init__(self, init_node=True, is_debug=False): - from collections import deque - import rospy - from cv_bridge import CvBridge - from sensor_msgs.msg import Image - self.is_debug = is_debug - self.bridge = CvBridge() - self.camera_names = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] - if init_node: - rospy.init_node('image_recorder', anonymous=True) - for cam_name in self.camera_names: - setattr(self, f'{cam_name}_image', None) - setattr(self, f'{cam_name}_secs', None) - setattr(self, f'{cam_name}_nsecs', None) - if cam_name == 'cam_high': - callback_func = self.image_cb_cam_high - elif cam_name == 'cam_low': - callback_func = self.image_cb_cam_low - elif cam_name == 'cam_left_wrist': - callback_func = self.image_cb_cam_left_wrist - elif cam_name == 'cam_right_wrist': - callback_func = self.image_cb_cam_right_wrist - else: - raise NotImplementedError - rospy.Subscriber(f"/usb_{cam_name}/image_raw", Image, callback_func) - if self.is_debug: - setattr(self, f'{cam_name}_timestamps', deque(maxlen=50)) - time.sleep(0.5) - - def image_cb(self, cam_name, data): - setattr(self, f'{cam_name}_image', self.bridge.imgmsg_to_cv2(data, desired_encoding='passthrough')) - setattr(self, f'{cam_name}_secs', data.header.stamp.secs) - setattr(self, f'{cam_name}_nsecs', 
data.header.stamp.nsecs) - # cv2.imwrite('/home/tonyzhao/Desktop/sample.jpg', cv_image) - if self.is_debug: - getattr(self, f'{cam_name}_timestamps').append(data.header.stamp.secs + data.header.stamp.secs * 1e-9) - - def image_cb_cam_high(self, data): - cam_name = 'cam_high' - return self.image_cb(cam_name, data) - - def image_cb_cam_low(self, data): - cam_name = 'cam_low' - return self.image_cb(cam_name, data) - - def image_cb_cam_left_wrist(self, data): - cam_name = 'cam_left_wrist' - return self.image_cb(cam_name, data) - - def image_cb_cam_right_wrist(self, data): - cam_name = 'cam_right_wrist' - return self.image_cb(cam_name, data) - - def get_images(self): - image_dict = dict() - for cam_name in self.camera_names: - image_dict[cam_name] = getattr(self, f'{cam_name}_image') - return image_dict - - def print_diagnostics(self): - def dt_helper(l): - l = np.array(l) - diff = l[1:] - l[:-1] - return np.mean(diff) - for cam_name in self.camera_names: - image_freq = 1 / dt_helper(getattr(self, f'{cam_name}_timestamps')) - print(f'{cam_name} {image_freq=:.2f}') - print() - -class Recorder: - def __init__(self, side, init_node=True, is_debug=False): - from collections import deque - import rospy - from sensor_msgs.msg import JointState - from interbotix_xs_msgs.msg import JointGroupCommand, JointSingleCommand - - self.secs = None - self.nsecs = None - self.qpos = None - self.effort = None - self.arm_command = None - self.gripper_command = None - self.is_debug = is_debug - - if init_node: - rospy.init_node('recorder', anonymous=True) - rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb) - rospy.Subscriber(f"/puppet_{side}/commands/joint_group", JointGroupCommand, self.puppet_arm_commands_cb) - rospy.Subscriber(f"/puppet_{side}/commands/joint_single", JointSingleCommand, self.puppet_gripper_commands_cb) - if self.is_debug: - self.joint_timestamps = deque(maxlen=50) - self.arm_command_timestamps = deque(maxlen=50) - self.gripper_command_timestamps = deque(maxlen=50) - time.sleep(0.1) - - def puppet_state_cb(self, data): - self.qpos = data.position - self.qvel = data.velocity - self.effort = data.effort - self.data = data - if self.is_debug: - self.joint_timestamps.append(time.time()) - - def puppet_arm_commands_cb(self, data): - self.arm_command = data.cmd - if self.is_debug: - self.arm_command_timestamps.append(time.time()) - - def puppet_gripper_commands_cb(self, data): - self.gripper_command = data.cmd - if self.is_debug: - self.gripper_command_timestamps.append(time.time()) - - def print_diagnostics(self): - def dt_helper(l): - l = np.array(l) - diff = l[1:] - l[:-1] - return np.mean(diff) - - joint_freq = 1 / dt_helper(self.joint_timestamps) - arm_command_freq = 1 / dt_helper(self.arm_command_timestamps) - gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps) - - print(f'{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n') - -def get_arm_joint_positions(bot): - return bot.arm.core.joint_states.position[:6] - -def get_arm_gripper_positions(bot): - joint_position = bot.gripper.core.joint_states.position[6] - return joint_position - -def move_arms(bot_list, target_pose_list, move_time=1): - num_steps = int(move_time / DT) - curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list] - traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)] - for t in range(num_steps): - for bot_id, bot in enumerate(bot_list): - 
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False) - time.sleep(DT) - -def move_grippers(bot_list, target_pose_list, move_time): - gripper_command = JointSingleCommand(name="gripper") - num_steps = int(move_time / DT) - curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list] - traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)] - for t in range(num_steps): - for bot_id, bot in enumerate(bot_list): - gripper_command.cmd = traj_list[bot_id][t] - bot.gripper.core.pub_single.publish(gripper_command) - time.sleep(DT) - -def setup_puppet_bot(bot): - bot.dxl.robot_reboot_motors("single", "gripper", True) - bot.dxl.robot_set_operating_modes("group", "arm", "position") - bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") - torque_on(bot) - -def setup_master_bot(bot): - bot.dxl.robot_set_operating_modes("group", "arm", "pwm") - bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") - torque_off(bot) - -def set_standard_pid_gains(bot): - bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 800) - bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0) - -def set_low_pid_gains(bot): - bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 100) - bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0) - -def torque_off(bot): - bot.dxl.robot_torque_enable("group", "arm", False) - bot.dxl.robot_torque_enable("single", "gripper", False) - -def torque_on(bot): - bot.dxl.robot_torque_enable("group", "arm", True) - bot.dxl.robot_torque_enable("single", "gripper", True) diff --git a/capvector-oft/experiments/robot/aloha/run_aloha_eval.py b/capvector-oft/experiments/robot/aloha/run_aloha_eval.py deleted file mode 100644 index cef67d8bb5f293e1929e5fdef049278546153d8a..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/aloha/run_aloha_eval.py +++ /dev/null @@ -1,385 +0,0 @@ -""" -run_aloha_eval.py - -Evaluates a model in a real-world ALOHA environment. -""" - -import logging -import os -import socket -import sys -import time -from collections import deque -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Union - -import draccus -import tqdm - -# Append current directory so that interpreter can find experiments.robot -sys.path.append(".") -from experiments.robot.aloha.aloha_utils import ( - get_aloha_env, - get_aloha_image, - get_aloha_wrist_images, - get_next_task_label, - save_rollout_video, -) -from experiments.robot.openvla_utils import ( - get_action_from_server, - resize_image_for_policy, -) -from experiments.robot.robot_utils import ( - DATE_TIME, - get_image_resize_size, - set_seed_everywhere, -) - -# Set up logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - - -@dataclass -class GenerateConfig: - # fmt: off - - ################################################################################################################# - # Model-specific parameters - ################################################################################################################# - model_family: str = "openvla" # Model family - - center_crop: bool = True # Center crop? 
(if trained w/ random crop image aug) - num_open_loop_steps: int = 25 # Number of actions to execute open-loop before requerying policy - - use_vla_server: bool = True # Whether to query remote VLA server for actions - vla_server_url: Union[str, Path] = "" # Remote VLA server URL (set to 127.0.0.1 if on same machine) - - ################################################################################################################# - # ALOHA environment-specific parameters - ################################################################################################################# - num_rollouts_planned: int = 50 # Number of test rollouts - max_steps: int = 1500 # Max number of steps per rollout - use_relative_actions: bool = False # Whether to use relative actions (delta joint angles) - - ################################################################################################################# - # Utils - ################################################################################################################# - run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging - local_log_dir: str = "./experiments/logs" # Local directory for eval logs - - seed: int = 7 # Random Seed (for reproducibility) - - # fmt: on - - -def validate_config(cfg: GenerateConfig) -> None: - """Validate configuration parameters.""" - assert cfg.use_vla_server, ( - "Must use VLA server (server-client interface) to query model and get actions! Please set --use_vla_server=True" - ) - - -def setup_logging(cfg: GenerateConfig): - """Set up logging to file.""" - # Create run ID - run_id = f"EVAL-{cfg.model_family}-{DATE_TIME}" - if cfg.run_id_note is not None: - run_id += f"--{cfg.run_id_note}" - - # Set up local logging - os.makedirs(cfg.local_log_dir, exist_ok=True) - local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt") - log_file = open(local_log_filepath, "w") - logger.info(f"Logging to local log file: {local_log_filepath}") - - return log_file, local_log_filepath, run_id - - -def log_message(message: str, log_file=None): - """Log a message to console and optionally to a log file.""" - print(message) - logger.info(message) - if log_file: - log_file.write(message + "\n") - log_file.flush() - - -def get_server_endpoint(cfg: GenerateConfig): - """Get the server endpoint for remote inference.""" - ip_address = socket.gethostbyname(cfg.vla_server_url) - return f"http://{ip_address}:8777/act" - - -def prepare_observation(obs, resize_size): - """Prepare observation for policy input.""" - # Get preprocessed images - img = get_aloha_image(obs) - left_wrist_img, right_wrist_img = get_aloha_wrist_images(obs) - - # Resize images to size expected by model - img_resized = resize_image_for_policy(img, resize_size) - left_wrist_img_resized = resize_image_for_policy(left_wrist_img, resize_size) - right_wrist_img_resized = resize_image_for_policy(right_wrist_img, resize_size) - - # Prepare observations dict - observation = { - "full_image": img_resized, - "left_wrist_image": left_wrist_img_resized, - "right_wrist_image": right_wrist_img_resized, - "state": obs.observation["qpos"], - } - - return observation, img_resized, left_wrist_img_resized, right_wrist_img_resized - - -def run_episode( - cfg: GenerateConfig, - env, - task_description: str, - server_endpoint: str, - resize_size, - log_file=None, -): - """Run a single episode in the ALOHA environment.""" - # Define control frequency - STEP_DURATION_IN_SEC = 1.0 / 25.0 - - # Reset environment - obs = env.reset() - - # 
Initialize action queue - action_queue = deque(maxlen=cfg.num_open_loop_steps) - - # Setup - t = 0 - curr_state = None - replay_images = [] - replay_images_resized = [] - replay_images_left_wrist_resized = [] - replay_images_right_wrist_resized = [] - - log_message("Prepare the scene, and then press Enter to begin...", log_file) - input() - - # Reset environment again to fetch first timestep observation - obs = env.reset() - - # Fetch initial robot state (but sleep first so that robot stops moving) - time.sleep(2) - curr_state = env.get_qpos() - - episode_start_time = time.time() - total_model_query_time = 0.0 - - try: - while t < cfg.max_steps: - # Get step start time (used to compute how much to sleep between steps) - step_start_time = time.time() - - # Get observation - obs = env.get_observation(t=t) - - # Save raw high camera image for replay video - replay_images.append(obs.observation["images"]["cam_high"]) - - # If action queue is empty, requery model - if len(action_queue) == 0: - # Prepare observation - observation, img_resized, left_wrist_resized, right_wrist_resized = prepare_observation(obs, resize_size) - observation["instruction"] = task_description - - # Save processed images for replay - replay_images_resized.append(img_resized) - replay_images_left_wrist_resized.append(left_wrist_resized) - replay_images_right_wrist_resized.append(right_wrist_resized) - - # Query model to get action - log_message("Requerying model...", log_file) - model_query_start_time = time.time() - actions = get_action_from_server(observation, server_endpoint) - actions = actions[: cfg.num_open_loop_steps] - total_model_query_time += time.time() - model_query_start_time - action_queue.extend(actions) - - # Get action from queue - action = action_queue.popleft() - log_message("-----------------------------------------------------", log_file) - log_message(f"t: {t}", log_file) - log_message(f"action: {action}", log_file) - - # Execute action in environment - if cfg.use_relative_actions: - # Get absolute joint angles from relative action - rel_action = action - target_state = curr_state + rel_action - obs = env.step(target_state.tolist()) - # Update current state (assume it is the commanded target state) - curr_state = target_state - else: - obs = env.step(action.tolist()) - t += 1 - - # Sleep until next timestep - step_elapsed_time = time.time() - step_start_time - if step_elapsed_time < STEP_DURATION_IN_SEC: - time_to_sleep = STEP_DURATION_IN_SEC - step_elapsed_time - log_message(f"Sleeping {time_to_sleep} sec...", log_file) - time.sleep(time_to_sleep) - - except (KeyboardInterrupt, Exception) as e: - if isinstance(e, KeyboardInterrupt): - log_message("\nCaught KeyboardInterrupt: Terminating episode early.", log_file) - else: - log_message(f"\nCaught exception: {e}", log_file) - - episode_end_time = time.time() - - # Get success feedback from user - user_input = input("Success? 
Enter 'y' or 'n': ") - success = True if user_input.lower() == "y" else False - - # Calculate episode statistics - episode_stats = { - "success": success, - "total_steps": t, - "model_query_time": total_model_query_time, - "episode_duration": episode_end_time - episode_start_time, - } - - return ( - episode_stats, - replay_images, - replay_images_resized, - replay_images_left_wrist_resized, - replay_images_right_wrist_resized, - ) - - -def save_episode_videos( - replay_images, - replay_images_resized, - replay_images_left_wrist, - replay_images_right_wrist, - episode_idx, - success, - task_description, - log_file=None, -): - """Save videos of the episode from different camera angles.""" - # Save main replay video - save_rollout_video(replay_images, episode_idx, success=success, task_description=task_description, log_file=log_file) - - # Save processed view videos - save_rollout_video( - replay_images_resized, - episode_idx, - success=success, - task_description=task_description, - log_file=log_file, - notes="resized", - ) - save_rollout_video( - replay_images_left_wrist, - episode_idx, - success=success, - task_description=task_description, - log_file=log_file, - notes="left_wrist_resized", - ) - save_rollout_video( - replay_images_right_wrist, - episode_idx, - success=success, - task_description=task_description, - log_file=log_file, - notes="right_wrist_resized", - ) - - -@draccus.wrap() -def eval_aloha(cfg: GenerateConfig) -> None: - """Main function to evaluate a trained policy in a real-world ALOHA environment.""" - # Validate configuration - validate_config(cfg) - - # Set random seed - set_seed_everywhere(cfg.seed) - - # Setup logging - log_file, local_log_filepath, run_id = setup_logging(cfg) - - # Get expected image dimensions - resize_size = get_image_resize_size(cfg) - - # Get ALOHA environment - env = get_aloha_env() - - # Get server endpoint for remote inference - server_endpoint = get_server_endpoint(cfg) - - # Initialize task description - task_description = "" - - # Start evaluation - num_rollouts_completed, total_successes = 0, 0 - - for episode_idx in tqdm.tqdm(range(cfg.num_rollouts_planned)): - # Get task description from user - task_description = get_next_task_label(task_description) - log_message(f"\nTask: {task_description}", log_file) - - log_message(f"Starting episode {num_rollouts_completed + 1}...", log_file) - - # Run episode - episode_stats, replay_images, replay_images_resized, replay_images_left_wrist, replay_images_right_wrist = ( - run_episode(cfg, env, task_description, server_endpoint, resize_size, log_file) - ) - - # Update counters - num_rollouts_completed += 1 - if episode_stats["success"]: - total_successes += 1 - - # Save videos - save_episode_videos( - replay_images, - replay_images_resized, - replay_images_left_wrist, - replay_images_right_wrist, - num_rollouts_completed, - episode_stats["success"], - task_description, - log_file, - ) - - # Log results - log_message(f"Success: {episode_stats['success']}", log_file) - log_message(f"# episodes completed so far: {num_rollouts_completed}", log_file) - log_message(f"# successes: {total_successes} ({total_successes / num_rollouts_completed * 100:.1f}%)", log_file) - log_message(f"Total model query time: {episode_stats['model_query_time']:.2f} sec", log_file) - log_message(f"Total episode elapsed time: {episode_stats['episode_duration']:.2f} sec", log_file) - - # Calculate final success rate - final_success_rate = float(total_successes) / float(num_rollouts_completed) if num_rollouts_completed > 0 else 0 - 
- # Log final results - log_message("\nFinal results:", log_file) - log_message(f"Total episodes: {num_rollouts_completed}", log_file) - log_message(f"Total successes: {total_successes}", log_file) - log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file) - - # Close log file - if log_file: - log_file.close() - - return final_success_rate - - -if __name__ == "__main__": - eval_aloha() diff --git a/capvector-oft/experiments/robot/libero/libero_requirements.txt b/capvector-oft/experiments/robot/libero/libero_requirements.txt deleted file mode 100644 index 2d496021919c8f2f21a2d277aca93e881aa9dd9f..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/libero/libero_requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -imageio[ffmpeg] -robosuite==1.4.1 -bddl -easydict -cloudpickle -gym diff --git a/capvector-oft/experiments/robot/libero/libero_utils.py b/capvector-oft/experiments/robot/libero/libero_utils.py deleted file mode 100644 index d3575da8db6d3d433d77451a38d707b3680800b9..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/libero/libero_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Utils for evaluating policies in LIBERO simulation environments.""" - -import math -import os - -import imageio -import numpy as np -import tensorflow as tf -from libero.libero import get_libero_path -from libero.libero.envs import OffScreenRenderEnv - -from experiments.robot.robot_utils import ( - DATE, - DATE_TIME, -) - - -def get_libero_env(task, model_family, resolution=256): - """Initializes and returns the LIBERO environment, along with the task description.""" - task_description = task.language - task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) - env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} - env = OffScreenRenderEnv(**env_args) - env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state - return env, task_description - - -def get_libero_dummy_action(model_family: str): - """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" - return [0, 0, 0, 0, 0, 0, -1] - - -def get_libero_image(obs): - """Extracts third-person image from observations and preprocesses it.""" - img = obs["agentview_image"] - img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing - return img - - -def get_libero_wrist_image(obs): - """Extracts wrist camera image from observations and preprocesses it.""" - img = obs["robot0_eye_in_hand_image"] - img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing - return img - - -def save_rollout_video(rollout_images, idx, success, task_description, log_file=None): - """Saves an MP4 replay of an episode.""" - rollout_dir = f"./rollouts/{DATE}" - os.makedirs(rollout_dir, exist_ok=True) - processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] - mp4_path = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}.mp4" - video_writer = imageio.get_writer(mp4_path, fps=30) - for img in rollout_images: - video_writer.append_data(img) - video_writer.close() - print(f"Saved rollout MP4 at path {mp4_path}") - if log_file is not None: - log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") - return mp4_path - - -def quat2axisangle(quat): - """ - Copied from 
robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 - - Converts quaternion to axis-angle format. - Returns a unit vector direction scaled by its angle in radians. - - Args: - quat (np.array): (x,y,z,w) vec4 float angles - - Returns: - np.array: (ax,ay,az) axis-angle exponential coordinates - """ - # clip quaternion - if quat[3] > 1.0: - quat[3] = 1.0 - elif quat[3] < -1.0: - quat[3] = -1.0 - - den = np.sqrt(1.0 - quat[3] * quat[3]) - if math.isclose(den, 0.0): - # This is (close to) a zero degree rotation, immediately return - return np.zeros(3) - - return (quat[:3] * 2.0 * math.acos(quat[3])) / den diff --git a/capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py b/capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py deleted file mode 100644 index 2643acb9f686684fdd019ce2568a1cd6685e0a59..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/libero/regenerate_libero_dataset.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -Regenerates a LIBERO dataset (HDF5 files) by replaying demonstrations in the environments. - -Notes: - - We save image observations at 256x256px resolution (instead of 128x128). - - We filter out transitions with "no-op" (zero) actions that do not change the robot's state. - - We filter out unsuccessful demonstrations. - - In the LIBERO HDF5 data -> RLDS data conversion (not shown here), we rotate the images by - 180 degrees because we observe that the environments return images that are upside down - on our platform. - -Usage: - python experiments/robot/libero/regenerate_libero_dataset.py \ - --libero_task_suite [ libero_spatial | libero_object | libero_goal | libero_10 ] \ - --libero_raw_data_dir \ - --libero_target_dir - - Example (LIBERO-Spatial): - python experiments/robot/libero/regenerate_libero_dataset.py \ - --libero_task_suite libero_spatial \ - --libero_raw_data_dir ./LIBERO/libero/datasets/libero_spatial \ - --libero_target_dir ./LIBERO/libero/datasets/libero_spatial_no_noops - -""" - -import argparse -import json -import os -import time - -import h5py -import numpy as np -import robosuite.utils.transform_utils as T -import tqdm -from libero.libero import benchmark - -from experiments.robot.libero.libero_utils import ( - get_libero_dummy_action, - get_libero_env, -) - - -IMAGE_RESOLUTION = 256 - - -def is_noop(action, prev_action=None, threshold=1e-4): - """ - Returns whether an action is a no-op action. - - A no-op action satisfies two criteria: - (1) All action dimensions, except for the last one (gripper action), are near zero. - (2) The gripper action is equal to the previous timestep's gripper action. - - Explanation of (2): - Naively filtering out actions with just criterion (1) is not good because you will - remove actions where the robot is staying still but opening/closing its gripper. - So you also need to consider the current state (by checking the previous timestep's - gripper action as a proxy) to determine whether the action really is a no-op. 
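    A quick worked example (hypothetical 7-DoF delta actions; values are illustrative only):

        is_noop([0, 0, 0, 0, 0, 0, 1.0], prev_action=None)                      -> True   (near-zero motion on the first step)
        is_noop([0, 0, 0, 0, 0, 0, 1.0], prev_action=[0, 0, 0, 0, 0, 0, -1.0])  -> False  (gripper toggled, so the transition is kept)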
- """ - # Special case: Previous action is None if this is the first action in the episode - # Then we only care about criterion (1) - if prev_action is None: - return np.linalg.norm(action[:-1]) < threshold - - # Normal case: Check both criteria (1) and (2) - gripper_action = action[-1] - prev_gripper_action = prev_action[-1] - return np.linalg.norm(action[:-1]) < threshold and gripper_action == prev_gripper_action - - -def main(args): - print(f"Regenerating {args.libero_task_suite} dataset!") - - # Create target directory - if os.path.isdir(args.libero_target_dir): - user_input = input(f"Target directory already exists at path: {args.libero_target_dir}\nEnter 'y' to overwrite the directory, or anything else to exit: ") - if user_input != 'y': - exit() - os.makedirs(args.libero_target_dir, exist_ok=True) - - # Prepare JSON file to record success/false and initial states per episode - metainfo_json_dict = {} - metainfo_json_out_path = f"./experiments/robot/libero/{args.libero_task_suite}_metainfo.json" - with open(metainfo_json_out_path, "w") as f: - # Just test that we can write to this file (we overwrite it later) - json.dump(metainfo_json_dict, f) - - # Get task suite - benchmark_dict = benchmark.get_benchmark_dict() - task_suite = benchmark_dict[args.libero_task_suite]() - num_tasks_in_suite = task_suite.n_tasks - - # Setup - num_replays = 0 - num_success = 0 - num_noops = 0 - - for task_id in tqdm.tqdm(range(num_tasks_in_suite)): - # Get task in suite - task = task_suite.get_task(task_id) - env, task_description = get_libero_env(task, "llava", resolution=IMAGE_RESOLUTION) - - # Get dataset for task - orig_data_path = os.path.join(args.libero_raw_data_dir, f"{task.name}_demo.hdf5") - assert os.path.exists(orig_data_path), f"Cannot find raw data file {orig_data_path}." 
- orig_data_file = h5py.File(orig_data_path, "r") - orig_data = orig_data_file["data"] - - # Create new HDF5 file for regenerated demos - new_data_path = os.path.join(args.libero_target_dir, f"{task.name}_demo.hdf5") - new_data_file = h5py.File(new_data_path, "w") - grp = new_data_file.create_group("data") - - for i in range(len(orig_data.keys())): - # Get demo data - demo_data = orig_data[f"demo_{i}"] - orig_actions = demo_data["actions"][()] - orig_states = demo_data["states"][()] - - # Reset environment, set initial state, and wait a few steps for environment to settle - env.reset() - env.set_init_state(orig_states[0]) - for _ in range(10): - obs, reward, done, info = env.step(get_libero_dummy_action("llava")) - - # Set up new data lists - states = [] - actions = [] - ee_states = [] - gripper_states = [] - joint_states = [] - robot_states = [] - agentview_images = [] - eye_in_hand_images = [] - - # Replay original demo actions in environment and record observations - for _, action in enumerate(orig_actions): - # Skip transitions with no-op actions - prev_action = actions[-1] if len(actions) > 0 else None - if is_noop(action, prev_action): - print(f"\tSkipping no-op action: {action}") - num_noops += 1 - continue - - if states == []: - # In the first timestep, since we're using the original initial state to initialize the environment, - # copy the initial state (first state in episode) over from the original HDF5 to the new one - states.append(orig_states[0]) - robot_states.append(demo_data["robot_states"][0]) - else: - # For all other timesteps, get state from environment and record it - states.append(env.sim.get_state().flatten()) - robot_states.append( - np.concatenate([obs["robot0_gripper_qpos"], obs["robot0_eef_pos"], obs["robot0_eef_quat"]]) - ) - - # Record original action (from demo) - actions.append(action) - - # Record data returned by environment - if "robot0_gripper_qpos" in obs: - gripper_states.append(obs["robot0_gripper_qpos"]) - joint_states.append(obs["robot0_joint_pos"]) - ee_states.append( - np.hstack( - ( - obs["robot0_eef_pos"], - T.quat2axisangle(obs["robot0_eef_quat"]), - ) - ) - ) - agentview_images.append(obs["agentview_image"]) - eye_in_hand_images.append(obs["robot0_eye_in_hand_image"]) - - # Execute demo action in environment - obs, reward, done, info = env.step(action.tolist()) - - # At end of episode, save replayed trajectories to new HDF5 files (only keep successes) - if done: - dones = np.zeros(len(actions)).astype(np.uint8) - dones[-1] = 1 - rewards = np.zeros(len(actions)).astype(np.uint8) - rewards[-1] = 1 - assert len(actions) == len(agentview_images) - - ep_data_grp = grp.create_group(f"demo_{i}") - obs_grp = ep_data_grp.create_group("obs") - obs_grp.create_dataset("gripper_states", data=np.stack(gripper_states, axis=0)) - obs_grp.create_dataset("joint_states", data=np.stack(joint_states, axis=0)) - obs_grp.create_dataset("ee_states", data=np.stack(ee_states, axis=0)) - obs_grp.create_dataset("ee_pos", data=np.stack(ee_states, axis=0)[:, :3]) - obs_grp.create_dataset("ee_ori", data=np.stack(ee_states, axis=0)[:, 3:]) - obs_grp.create_dataset("agentview_rgb", data=np.stack(agentview_images, axis=0)) - obs_grp.create_dataset("eye_in_hand_rgb", data=np.stack(eye_in_hand_images, axis=0)) - ep_data_grp.create_dataset("actions", data=actions) - ep_data_grp.create_dataset("states", data=np.stack(states)) - ep_data_grp.create_dataset("robot_states", data=np.stack(robot_states, axis=0)) - ep_data_grp.create_dataset("rewards", data=rewards) - 
ep_data_grp.create_dataset("dones", data=dones) - - num_success += 1 - - num_replays += 1 - - # Record success/false and initial environment state in metainfo dict - task_key = task_description.replace(" ", "_") - episode_key = f"demo_{i}" - if task_key not in metainfo_json_dict: - metainfo_json_dict[task_key] = {} - if episode_key not in metainfo_json_dict[task_key]: - metainfo_json_dict[task_key][episode_key] = {} - metainfo_json_dict[task_key][episode_key]["success"] = bool(done) - metainfo_json_dict[task_key][episode_key]["initial_state"] = orig_states[0].tolist() - - # Write metainfo dict to JSON file - # (We repeatedly overwrite, rather than doing this once at the end, just in case the script crashes midway) - with open(metainfo_json_out_path, "w") as f: - json.dump(metainfo_json_dict, f, indent=2) - - # Count total number of successful replays so far - print( - f"Total # episodes replayed: {num_replays}, Total # successes: {num_success} ({num_success / num_replays * 100:.1f} %)" - ) - - # Report total number of no-op actions filtered out so far - print(f" Total # no-op actions filtered out: {num_noops}") - - # Close HDF5 files - orig_data_file.close() - new_data_file.close() - print(f"Saved regenerated demos for task '{task_description}' at: {new_data_path}") - - print(f"Dataset regeneration complete! Saved new dataset at: {args.libero_target_dir}") - print(f"Saved metainfo JSON at: {metainfo_json_out_path}") - - -if __name__ == "__main__": - # Parse command-line arguments - parser = argparse.ArgumentParser() - parser.add_argument("--libero_task_suite", type=str, choices=["libero_spatial", "libero_object", "libero_goal", "libero_10", "libero_90"], - help="LIBERO task suite. Example: libero_spatial", required=True) - parser.add_argument("--libero_raw_data_dir", type=str, - help="Path to directory containing raw HDF5 dataset. Example: ./LIBERO/libero/datasets/libero_spatial", required=True) - parser.add_argument("--libero_target_dir", type=str, - help="Path to regenerated dataset directory. Example: ./LIBERO/libero/datasets/libero_spatial_no_noops", required=True) - args = parser.parse_args() - - # Start data regeneration - main(args) diff --git a/capvector-oft/experiments/robot/libero/run_libero_eval.py b/capvector-oft/experiments/robot/libero/run_libero_eval.py deleted file mode 100644 index e1c3085e89932acf7b32786e9005f9a765617576..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/libero/run_libero_eval.py +++ /dev/null @@ -1,540 +0,0 @@ -""" -run_libero_eval.py - -Evaluates a trained policy in a LIBERO simulation benchmark task suite. 
-""" - -import json -import logging -import os -import sys -from collections import deque -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Optional, Union - -import draccus -import numpy as np -import tqdm -from libero.libero import benchmark - -import wandb - -# Append current directory so that interpreter can find experiments.robot -sys.path.append("../..") -from experiments.robot.libero.libero_utils import ( - get_libero_dummy_action, - get_libero_env, - get_libero_image, - get_libero_wrist_image, - quat2axisangle, - save_rollout_video, -) -from experiments.robot.openvla_utils import ( - get_action_head, - get_noisy_action_projector, - get_processor, - get_proprio_projector, - resize_image_for_policy, -) -from experiments.robot.robot_utils import ( - DATE_TIME, - get_action, - get_image_resize_size, - get_model, - invert_gripper_action, - normalize_gripper_action, - set_seed_everywhere, -) -from prismatic.vla.constants import NUM_ACTIONS_CHUNK - - -# import debugpy -# try: -# debugpy.listen(("localhost", 9501)) -# print("Waiting for debugger attach") -# debugpy.wait_for_client() -# except Exception as e: -# pass - - -# Define task suite constants -class TaskSuite(str, Enum): - LIBERO_SPATIAL = "libero_spatial" - LIBERO_OBJECT = "libero_object" - LIBERO_GOAL = "libero_goal" - LIBERO_10 = "libero_10" - LIBERO_90 = "libero_90" - - -# Define max steps for each task suite -TASK_MAX_STEPS = { - TaskSuite.LIBERO_SPATIAL: 220, # longest training demo has 193 steps - TaskSuite.LIBERO_OBJECT: 280, # longest training demo has 254 steps - TaskSuite.LIBERO_GOAL: 300, # longest training demo has 270 steps - TaskSuite.LIBERO_10: 520, # longest training demo has 505 steps - TaskSuite.LIBERO_90: 400, # longest training demo has 373 steps -} - - -# Set up logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[logging.StreamHandler()], -) -logger = logging.getLogger(__name__) - - -@dataclass -class GenerateConfig: - # fmt: off - - ################################################################################################################# - # Model-specific parameters - ################################################################################################################# - model_family: str = "openvla" # Model family - pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path - - use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective - use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM) - num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training - num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference - use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features - num_images_in_input: int = 2 # Number of images in the VLA input (default: 1) - use_proprio: bool = True # Whether to include proprio state in input - - center_crop: bool = True # Center crop? (if trained w/ random crop image aug) - num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy - - lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!) 
- - unnorm_key: Union[str, Path] = "" # Action un-normalization key - - load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization - load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization - - ################################################################################################################# - # LIBERO environment-specific parameters - ################################################################################################################# - task_suite_name: str = TaskSuite.LIBERO_SPATIAL # Task suite - num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim - num_trials_per_task: int = 50 # Number of rollouts per task - initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file - env_img_res: int = 256 # Resolution for environment images (not policy input resolution) - - ################################################################################################################# - # Utils - ################################################################################################################# - run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging - local_log_dir: str = "./experiments/logs" # Local directory for eval logs - - use_wandb: bool = False # Whether to also log results in Weights & Biases - wandb_entity: str = "your-wandb-entity" # Name of WandB entity - wandb_project: str = "your-wandb-project" # Name of WandB project - - seed: int = 7 # Random Seed (for reproducibility) - - # fmt: on - - -def validate_config(cfg: GenerateConfig) -> None: - """Validate configuration parameters.""" - assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!" - - if "image_aug" in str(cfg.pretrained_checkpoint): - assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!" - - assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!" - - # Validate task suite - assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}" - - -def initialize_model(cfg: GenerateConfig): - """Initialize model and associated components.""" - # Load model - model = get_model(cfg) - - # Load proprio projector if needed - proprio_projector = None - if cfg.use_proprio: - proprio_projector = get_proprio_projector( - cfg, - model.llm_dim, - proprio_dim=8, # 8-dimensional proprio for LIBERO - ) - - # Load action head if needed - action_head = None - if cfg.use_l1_regression or cfg.use_diffusion: - action_head = get_action_head(cfg, model.llm_dim) - - # Load noisy action projector if using diffusion - noisy_action_projector = None - if cfg.use_diffusion: - noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim) - - # Get OpenVLA processor if needed - processor = None - if cfg.model_family == "openvla": - processor = get_processor(cfg) - check_unnorm_key(cfg, model) - - return model, action_head, proprio_projector, noisy_action_projector, processor - - -def check_unnorm_key(cfg: GenerateConfig, model) -> None: - """Check that the model contains the action un-normalization key.""" - # Initialize unnorm_key - unnorm_key = cfg.task_suite_name - - # In some cases, the key must be manually modified (e.g. 
after training on a modified version of the dataset - # with the suffix "_no_noops" in the dataset name) - if unnorm_key not in model.norm_stats and f"{unnorm_key}_no_noops" in model.norm_stats: - unnorm_key = f"{unnorm_key}_no_noops" - - assert unnorm_key in model.norm_stats, f"Action un-norm key {unnorm_key} not found in VLA `norm_stats`!" - - # Set the unnorm_key in cfg - cfg.unnorm_key = unnorm_key - - -def setup_logging(cfg: GenerateConfig): - """Set up logging to file and optionally to wandb.""" - # Create run ID - run_id = f"EVAL-{cfg.task_suite_name}-{cfg.model_family}-{DATE_TIME}" - if cfg.run_id_note is not None: - run_id += f"--{cfg.run_id_note}" - - # Set up local logging - os.makedirs(cfg.local_log_dir, exist_ok=True) - local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt") - log_file = open(local_log_filepath, "w") - logger.info(f"Logging to local log file: {local_log_filepath}") - - # Initialize Weights & Biases logging if enabled - if cfg.use_wandb: - wandb.init( - entity=cfg.wandb_entity, - project=cfg.wandb_project, - name=run_id, - ) - - return log_file, local_log_filepath, run_id - - -def log_message(message: str, log_file=None): - """Log a message to console and optionally to a log file.""" - logger.info(message) - if log_file: - log_file.write(message + "\n") - log_file.flush() - - -def load_initial_states(cfg: GenerateConfig, task_suite, task_id: int, log_file=None): - """Load initial states for the given task.""" - # Get default initial states - initial_states = task_suite.get_task_init_states(task_id) - - # If using custom initial states, load them from file - if cfg.initial_states_path != "DEFAULT": - with open(cfg.initial_states_path, "r") as f: - all_initial_states = json.load(f) - log_message(f"Using initial states from {cfg.initial_states_path}", log_file) - return initial_states, all_initial_states - else: - log_message("Using default initial states", log_file) - return initial_states, None - - -def prepare_observation(obs, resize_size): - """Prepare observation for policy input.""" - # Get preprocessed images - img = get_libero_image(obs) - wrist_img = get_libero_wrist_image(obs) - - # Resize images to size expected by model - img_resized = resize_image_for_policy(img, resize_size) - wrist_img_resized = resize_image_for_policy(wrist_img, resize_size) - - # Prepare observations dict - observation = { - "full_image": img_resized, - "wrist_image": wrist_img_resized, - "state": np.concatenate( - (obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"]) - ), - } - - return observation, img # Return both processed observation and original image for replay - - -def process_action(action, model_family): - """Process action before sending to environment.""" - # Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter - action = normalize_gripper_action(action, binarize=True) - - # [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets - # (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action - if model_family == "openvla": - action = invert_gripper_action(action) - - return action - - -def run_episode( - cfg: GenerateConfig, - env, - task_description: str, - model, - resize_size, - processor=None, - action_head=None, - proprio_projector=None, - noisy_action_projector=None, - initial_state=None, - log_file=None, -): - """Run a single episode in the environment.""" - # Reset environment - env.reset() - - # Set initial 
state if provided - if initial_state is not None: - obs = env.set_init_state(initial_state) - else: - obs = env.get_observation() - - # Initialize action queue - if cfg.num_open_loop_steps != NUM_ACTIONS_CHUNK: - print(f"WARNING: cfg.num_open_loop_steps ({cfg.num_open_loop_steps}) does not match the NUM_ACTIONS_CHUNK " - f"({NUM_ACTIONS_CHUNK}) constant defined in prismatic.vla.constants! For best performance (in terms of " - "both speed and success rate), we recommend executing the full action chunk.") - action_queue = deque(maxlen=cfg.num_open_loop_steps) - - # Setup - t = 0 - replay_images = [] - max_steps = TASK_MAX_STEPS[cfg.task_suite_name] - - # Run episode - success = False - try: - while t < max_steps + cfg.num_steps_wait: - # Do nothing for the first few timesteps to let objects stabilize - if t < cfg.num_steps_wait: - obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family)) - t += 1 - continue - - # Prepare observation - observation, img = prepare_observation(obs, resize_size) - replay_images.append(img) - - # If action queue is empty, requery model - if len(action_queue) == 0: - # Query model to get action - actions = get_action( - cfg, - model, - observation, - task_description, - processor=processor, - action_head=action_head, - proprio_projector=proprio_projector, - noisy_action_projector=noisy_action_projector, - use_film=cfg.use_film, - ) - action_queue.extend(actions) - - # Get action from queue - action = action_queue.popleft() - - # Process action - action = process_action(action, cfg.model_family) - - # Execute action in environment - obs, reward, done, info = env.step(action.tolist()) - if done: - success = True - break - t += 1 - - except Exception as e: - log_message(f"Episode error: {e}", log_file) - - return success, replay_images - - -def run_task( - cfg: GenerateConfig, - task_suite, - task_id: int, - model, - resize_size, - processor=None, - action_head=None, - proprio_projector=None, - noisy_action_projector=None, - total_episodes=0, - total_successes=0, - log_file=None, -): - """Run evaluation for a single task.""" - # Get task - task = task_suite.get_task(task_id) - - # Get initial states - initial_states, all_initial_states = load_initial_states(cfg, task_suite, task_id, log_file) - - # Initialize environment and get task description - env, task_description = get_libero_env(task, cfg.model_family, resolution=cfg.env_img_res) - - # Start episodes - task_episodes, task_successes = 0, 0 - for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)): - log_message(f"\nTask: {task_description}", log_file) - - # Handle initial state - if cfg.initial_states_path == "DEFAULT": - # Use default initial state - initial_state = initial_states[episode_idx] - else: - # Get keys for fetching initial episode state from JSON - initial_states_task_key = task_description.replace(" ", "_") - episode_key = f"demo_{episode_idx}" - - # Skip episode if expert demonstration failed to complete the task - if not all_initial_states[initial_states_task_key][episode_key]["success"]: - log_message(f"Skipping task {task_id} episode {episode_idx} due to failed expert demo!", log_file) - continue - - # Get initial state - initial_state = np.array(all_initial_states[initial_states_task_key][episode_key]["initial_state"]) - - log_message(f"Starting episode {task_episodes + 1}...", log_file) - - # Run episode - success, replay_images = run_episode( - cfg, - env, - task_description, - model, - resize_size, - processor, - action_head, - proprio_projector, - 
noisy_action_projector, - initial_state, - log_file, - ) - - # Update counters - task_episodes += 1 - total_episodes += 1 - if success: - task_successes += 1 - total_successes += 1 - - # Save replay video - save_rollout_video( - replay_images, total_episodes, success=success, task_description=task_description, log_file=log_file - ) - - # Log results - log_message(f"Success: {success}", log_file) - log_message(f"# episodes completed so far: {total_episodes}", log_file) - log_message(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)", log_file) - - # Log task results - task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0 - total_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0 - - log_message(f"Current task success rate: {task_success_rate}", log_file) - log_message(f"Current total success rate: {total_success_rate}", log_file) - - # Log to wandb if enabled - if cfg.use_wandb: - wandb.log( - { - f"success_rate/{task_description}": task_success_rate, - f"num_episodes/{task_description}": task_episodes, - } - ) - - return total_episodes, total_successes - - -@draccus.wrap() -def eval_libero(cfg: GenerateConfig) -> float: - """Main function to evaluate a trained policy on LIBERO benchmark tasks.""" - # Validate configuration - validate_config(cfg) - - # Set random seed - set_seed_everywhere(cfg.seed) - - # Initialize model and components - model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg) - - # Get expected image dimensions - resize_size = get_image_resize_size(cfg) - - # Setup logging - log_file, local_log_filepath, run_id = setup_logging(cfg) - - # Initialize LIBERO task suite - benchmark_dict = benchmark.get_benchmark_dict() - task_suite = benchmark_dict[cfg.task_suite_name]() - num_tasks = task_suite.n_tasks - - log_message(f"Task suite: {cfg.task_suite_name}", log_file) - - # Start evaluation - total_episodes, total_successes = 0, 0 - for task_id in tqdm.tqdm(range(num_tasks)): - total_episodes, total_successes = run_task( - cfg, - task_suite, - task_id, - model, - resize_size, - processor, - action_head, - proprio_projector, - noisy_action_projector, - total_episodes, - total_successes, - log_file, - ) - - # Calculate final success rate - final_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0 - - # Log final results - log_message("Final results:", log_file) - log_message(f"Total episodes: {total_episodes}", log_file) - log_message(f"Total successes: {total_successes}", log_file) - log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file) - - # Log to wandb if enabled - if cfg.use_wandb: - wandb.log( - { - "success_rate/total": final_success_rate, - "num_episodes/total": total_episodes, - } - ) - wandb.save(local_log_filepath) - - # Close log file - if log_file: - log_file.close() - - return final_success_rate - - -if __name__ == "__main__": - eval_libero() diff --git a/capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl b/capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl deleted file mode 100644 index 508c2640125e9279e4e5e2a9fbeadca1dd9bade2..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/libero/sample_libero_spatial_observation.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid 
sha256:326db6c78dd0a9d91c11f05af03b93fa3095338ee3cb5a5eb15adf3d87eb0109 -size 301501 diff --git a/capvector-oft/experiments/robot/openvla_utils.py b/capvector-oft/experiments/robot/openvla_utils.py deleted file mode 100644 index 9478246f9812d75fc4e1b4b9bd4d26dffdca1900..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/openvla_utils.py +++ /dev/null @@ -1,818 +0,0 @@ -"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies.""" - -import filecmp -import json -import os -import shutil -import time -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import json_numpy -import numpy as np -import requests -import tensorflow as tf -import torch -from huggingface_hub import HfApi, hf_hub_download -from PIL import Image -from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor - -# Apply JSON numpy patch for serialization -json_numpy.patch() - -from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig -from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction -from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor -from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead -from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone -from prismatic.models.projectors import NoisyActionProjector, ProprioProjector -from prismatic.vla.constants import ( - ACTION_DIM, - ACTION_PROPRIO_NORMALIZATION_TYPE, -) -from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType - -# Initialize important constants -DATE = time.strftime("%Y_%m_%d") -DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") -DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -OPENVLA_IMAGE_SIZE = 224 # Standard image size expected by OpenVLA - -# Configure NumPy print settings -np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) - - -def model_is_on_hf_hub(model_path: str) -> bool: - """Checks whether a model path points to a model on Hugging Face Hub.""" - # If the API call below runs without error, the model is on the hub - try: - HfApi().model_info(model_path) - return True - except Exception: - return False - - -def update_auto_map(pretrained_checkpoint: str) -> None: - """ - Update the AutoMap configuration in the checkpoint config.json file. - - This loads the config.json file inside the checkpoint directory and overwrites - the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes. 
- - Args: - pretrained_checkpoint: Path to the checkpoint directory - """ - if not os.path.isdir(pretrained_checkpoint): - return - - config_path = os.path.join(pretrained_checkpoint, "config.json") - if not os.path.exists(config_path): - print(f"Warning: No config.json found at {config_path}") - return - - # Create timestamped backup - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}") - shutil.copy2(config_path, backup_path) - print(f"Created backup of original config at: {os.path.abspath(backup_path)}") - - # Read and update the config - with open(config_path, "r") as f: - config = json.load(f) - - config["auto_map"] = { - "AutoConfig": "configuration_prismatic.OpenVLAConfig", - "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction", - } - - # Write back the updated config - with open(config_path, "w") as f: - json.dump(config, f, indent=2) - - print(f"Updated config.json at: {os.path.abspath(config_path)}") - print("Changes made:") - print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"') - print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"') - - -def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool: - """ - Check if two files are identical in content. - - Args: - path1: Path to the first file - path2: Path to the second file - - Returns: - bool: True if files are identical, False otherwise - """ - path1, path2 = Path(path1), Path(path2) - - # First check if file sizes match - if path1.stat().st_size != path2.stat().st_size: - return False - - # Check if contents match - return filecmp.cmp(path1, path2, shallow=False) - - -def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None: - """ - Handle syncing of files between current directory and checkpoint. - - Creates backups if files exist but differ, and copies current versions to checkpoint. - - Args: - curr_filepath: Path to the current file version - checkpoint_filepath: Path where the file should be in the checkpoint - file_type: Description of the file type for logging - """ - if os.path.exists(checkpoint_filepath): - # Check if existing files are identical - match = check_identical_files(curr_filepath, checkpoint_filepath) - - if not match: - print( - "\n------------------------------------------------------------------------------------------------\n" - f"Found mismatch between:\n" - f"Current: {curr_filepath}\n" - f"Checkpoint: {checkpoint_filepath}\n" - ) - - # Create timestamped backup - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = f"{checkpoint_filepath}.back.{timestamp}" - shutil.copy2(checkpoint_filepath, backup_path) - print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}") - - # Copy current version to checkpoint directory - shutil.copy2(curr_filepath, checkpoint_filepath) - print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}") - print( - f"Changes complete. 
The checkpoint will now use the current version of {file_type}" - "\n------------------------------------------------------------------------------------------------\n" - ) - else: - # If file doesn't exist in checkpoint directory, copy it - shutil.copy2(curr_filepath, checkpoint_filepath) - print( - "\n------------------------------------------------------------------------------------------------\n" - f"No {file_type} found in checkpoint directory.\n" - f"Copied current version from: {curr_filepath}\n" - f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}" - "\n------------------------------------------------------------------------------------------------\n" - ) - - -def check_model_logic_mismatch(pretrained_checkpoint: str) -> None: - """ - Check and sync model logic files between current code and checkpoint. - - Handles the relationship between current and checkpoint versions of both - modeling_prismatic.py and configuration_prismatic.py: - - If checkpoint file exists and differs: creates backup and copies current version - - If checkpoint file doesn't exist: copies current version - - Args: - pretrained_checkpoint: Path to the checkpoint directory - """ - if not os.path.isdir(pretrained_checkpoint): - return - - # Find current files - curr_files = {"modeling_prismatic.py": None, "configuration_prismatic.py": None} - - for root, _, files in os.walk("./prismatic/"): - for filename in curr_files.keys(): - if filename in files and curr_files[filename] is None: - curr_files[filename] = os.path.join(root, filename) - - # Check and handle each file - for filename, curr_filepath in curr_files.items(): - if curr_filepath is None: - print(f"WARNING: `{filename}` is not found anywhere in the current directory.") - continue - - checkpoint_filepath = os.path.join(pretrained_checkpoint, filename) - _handle_file_sync(curr_filepath, checkpoint_filepath, filename) - - -def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str: - """ - Find a specific checkpoint file matching a pattern. - - Args: - pretrained_checkpoint: Path to the checkpoint directory - file_pattern: String pattern to match in filenames - - Returns: - str: Path to the matching checkpoint file - - Raises: - AssertionError: If no files or multiple files match the pattern - """ - assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}" - - checkpoint_files = [] - for filename in os.listdir(pretrained_checkpoint): - if file_pattern in filename and "checkpoint" in filename: - full_path = os.path.join(pretrained_checkpoint, filename) - checkpoint_files.append(full_path) - - assert len(checkpoint_files) == 1, ( - f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}" - ) - - return checkpoint_files[0] - - -def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]: - """ - Load a component's state dict from checkpoint and handle DDP prefix if present. - - Args: - checkpoint_path: Path to the checkpoint file - - Returns: - Dict: The processed state dictionary for loading - """ - state_dict = torch.load(checkpoint_path, weights_only=True) - - # If the component was trained with DDP, elements in the state dict have prefix "module." 
which we must remove - new_state_dict = {} - for k, v in state_dict.items(): - if k.startswith("module."): - new_state_dict[k[7:]] = v - else: - new_state_dict[k] = v - - return new_state_dict - - -def get_vla(cfg: Any) -> torch.nn.Module: - """ - Load and initialize the VLA model from checkpoint. - - Args: - cfg: Configuration object - - Returns: - torch.nn.Module: The initialized VLA model - """ - print("Instantiating pretrained VLA policy...") - - # If loading a locally stored pretrained checkpoint, check whether config or model files - # need to be synced so that any changes the user makes to the VLA modeling code will - # actually go into effect - # If loading a pretrained checkpoint from Hugging Face Hub, we just assume that the policy - # will be used as is, with its original modeling logic - if not model_is_on_hf_hub(cfg.pretrained_checkpoint): - # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) - AutoConfig.register("openvla", OpenVLAConfig) - AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) - AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) - AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) - - # Update config.json and sync model files - update_auto_map(cfg.pretrained_checkpoint) - check_model_logic_mismatch(cfg.pretrained_checkpoint) - - # Load the model - vla = AutoModelForVision2Seq.from_pretrained( - cfg.pretrained_checkpoint, - # attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, - load_in_8bit=cfg.load_in_8bit, - load_in_4bit=cfg.load_in_4bit, - low_cpu_mem_usage=True, - trust_remote_code=True, - ) - - # If using FiLM, wrap the vision backbone to allow for infusion of language inputs - if cfg.use_film: - vla = _apply_film_to_vla(vla, cfg) - - # Set number of images in model input - vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input) - - vla.eval() - - # Move model to device if not using quantization - if not cfg.load_in_8bit and not cfg.load_in_4bit: - vla = vla.to(DEVICE) - - # Load dataset stats for action normalization - _load_dataset_stats(vla, cfg.pretrained_checkpoint) - - return vla - - -def _apply_film_to_vla(vla: torch.nn.Module, cfg: Any) -> torch.nn.Module: - """ - Apply FiLM (Feature-wise Linear Modulation) to the VLA vision backbone. - - Args: - vla: The VLA model - cfg: Configuration object with model parameters - - Returns: - torch.nn.Module: VLA model with FiLM applied - """ - from peft import LoraConfig, get_peft_model - - # Apply LoRA configuration - lora_config = LoraConfig( - r=cfg.lora_rank, - lora_alpha=min(cfg.lora_rank, 16), - lora_dropout=0.0, - target_modules="all-linear", - init_lora_weights="gaussian", - ) - vla = get_peft_model(vla, lora_config) - - # Create and apply FiLMed vision backbone - new_vision_backbone = FiLMedPrismaticVisionBackbone( - vision_backbone=vla.vision_backbone, llm_dim=vla.llm_dim, - ) - vla.model.vision_backbone = new_vision_backbone - - # Load vision backbone checkpoint - checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "vision_backbone") - state_dict = torch.load(checkpoint_path, weights_only=True) - vla.model.vision_backbone.load_state_dict(state_dict) - - # Use the model component instead of wrapper and convert to bfloat16 - vla = vla.model - vla.vision_backbone = vla.vision_backbone.to(torch.bfloat16) - - return vla - - -def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None: - """ - Load dataset statistics used during training for action normalization. 
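    A rough sketch of the expected JSON layout, inferred from how `norm_stats` is consumed
    elsewhere in this file (one entry per dataset/unnorm key; exact fields may vary):

        {
          "libero_spatial_no_noops": {
            "action":  {"min": [...], "max": [...], "q01": [...], "q99": [...], "mask": [...]},
            "proprio": {"min": [...], "max": [...], "q01": [...], "q99": [...], "mask": [...]}
          }
        }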
- - Args: - vla: The VLA model - checkpoint_path: Path to the checkpoint directory - """ - if model_is_on_hf_hub(checkpoint_path): - # Download dataset stats directly from HF Hub - dataset_statistics_path = hf_hub_download( - repo_id=checkpoint_path, - filename="dataset_statistics.json", - ) - else: - dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json") - if os.path.isfile(dataset_statistics_path): - with open(dataset_statistics_path, "r") as f: - norm_stats = json.load(f) - vla.norm_stats = norm_stats - else: - print( - "WARNING: No local dataset_statistics.json file found for current checkpoint.\n" - "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint." - "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`." - ) - - -def get_processor(cfg: Any) -> AutoProcessor: - """ - Get the VLA model's Hugging Face processor. - - Args: - cfg: Configuration object with model parameters - - Returns: - AutoProcessor: The model's processor - """ - return AutoProcessor.from_pretrained(cfg.pretrained_checkpoint, trust_remote_code=True) - - -def get_proprio_projector(cfg: Any, llm_dim: int, proprio_dim: int) -> ProprioProjector: - """ - Get proprioception projector for the VLA model. - - Args: - cfg: Configuration object with model parameters - llm_dim: Dimension of the language model - proprio_dim: Dimension of proprioception data - - Returns: - ProprioProjector: The initialized proprio projector - """ - # Initialize projector and move to device - proprio_projector = ProprioProjector( - llm_dim=llm_dim, - proprio_dim=proprio_dim, - ).to(DEVICE) - proprio_projector = proprio_projector.to(torch.bfloat16).to(DEVICE) - proprio_projector.eval() - - # Find and load checkpoint (may be on Hugging Face Hub or stored locally) - if model_is_on_hf_hub(cfg.pretrained_checkpoint): - model_path_to_proprio_projector_name = { - "moojink/openvla-7b-oft-finetuned-libero-spatial": "proprio_projector--150000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-object": "proprio_projector--150000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-goal": "proprio_projector--50000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-10": "proprio_projector--150000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "proprio_projector--300000_checkpoint.pt", - } - if cfg.pretrained_checkpoint not in model_path_to_proprio_projector_name.keys(): - raise ValueError("Unsupported HF Hub pretrained checkpoint found!") - # Download proprio projector directly from HF Hub - proprio_projector_path = hf_hub_download( - repo_id=cfg.pretrained_checkpoint, filename=model_path_to_proprio_projector_name[cfg.pretrained_checkpoint] - ) - state_dict = load_component_state_dict(proprio_projector_path) - proprio_projector.load_state_dict(state_dict) - else: - checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "proprio_projector") - state_dict = load_component_state_dict(checkpoint_path) - proprio_projector.load_state_dict(state_dict) - - return proprio_projector - - -def get_noisy_action_projector(cfg: Any, llm_dim: int) -> NoisyActionProjector: - """ - Get noisy action projector for diffusion-based action prediction. 
- - Args: - cfg: Configuration object with model parameters - llm_dim: Dimension of the language model - - Returns: - NoisyActionProjector: The initialized noisy action projector - """ - # Initialize projector and move to device - noisy_action_projector = NoisyActionProjector( - llm_dim=llm_dim, - ).to(DEVICE) - noisy_action_projector = noisy_action_projector.to(torch.bfloat16).to(DEVICE) - noisy_action_projector.eval() - - # Find and load checkpoint - checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "noisy_action_projector") - state_dict = load_component_state_dict(checkpoint_path) - noisy_action_projector.load_state_dict(state_dict) - - return noisy_action_projector - - -def get_action_head(cfg: Any, llm_dim: int) -> Union[L1RegressionActionHead, DiffusionActionHead]: - """ - Get action head for continuous value prediction. - - Args: - cfg: Configuration object with model parameters - llm_dim: Dimension of the language model - - Returns: - Union[L1RegressionActionHead, DiffusionActionHead]: The initialized action head - - Raises: - AssertionError: If both L1 regression and diffusion are specified - """ - assert not (cfg.use_l1_regression and cfg.use_diffusion), "Cannot use both L1 regression and diffusion action head!" - - # Initialize appropriate action head based on configuration - if cfg.use_l1_regression: - action_head = L1RegressionActionHead(input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM) - elif cfg.use_diffusion: - action_head = DiffusionActionHead( - input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM, num_diffusion_steps_train=cfg.num_diffusion_steps_train - ) - # Set number of diffusion steps for inference - action_head.noise_scheduler.set_timesteps(cfg.num_diffusion_steps_inference) - else: - raise ValueError("Either use_l1_regression or use_diffusion must be True") - - action_head = action_head.to(torch.bfloat16).to(DEVICE) - action_head.eval() - - # Find and load checkpoint (may be on Hugging Face Hub or stored locally) - if model_is_on_hf_hub(cfg.pretrained_checkpoint): - model_path_to_action_head_name = { - "moojink/openvla-7b-oft-finetuned-libero-spatial": "action_head--150000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-object": "action_head--150000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-goal": "action_head--50000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-10": "action_head--150000_checkpoint.pt", - "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "action_head--300000_checkpoint.pt", - } - if cfg.pretrained_checkpoint not in model_path_to_action_head_name.keys(): - raise ValueError("Unsupported HF Hub pretrained checkpoint found!") - # Download proprio projector directly from HF Hub - action_head_path = hf_hub_download( - repo_id=cfg.pretrained_checkpoint, filename=model_path_to_action_head_name[cfg.pretrained_checkpoint] - ) - state_dict = load_component_state_dict(action_head_path) - action_head.load_state_dict(state_dict) - else: - checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "action_head") - state_dict = load_component_state_dict(checkpoint_path) - action_head.load_state_dict(state_dict) - - return action_head - - -def resize_image_for_policy(img: np.ndarray, resize_size: Union[int, Tuple[int, int]]) -> np.ndarray: - """ - Resize an image to match the policy's expected input size. - - Uses the same resizing scheme as in the training data pipeline for distribution matching. 
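    (The JPEG encode/decode round trip below mirrors the RLDS dataset builder, so that
    evaluation-time images carry the same compression characteristics the model saw during
    training; this is the "distribution matching" referred to above.)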
- - Args: - img: Numpy array containing the image - resize_size: Target size as int (square) or (height, width) tuple - - Returns: - np.ndarray: The resized image - """ - assert isinstance(resize_size, int) or isinstance(resize_size, tuple) - if isinstance(resize_size, int): - resize_size = (resize_size, resize_size) - - # Resize using the same pipeline as in RLDS dataset builder - img = tf.image.encode_jpeg(img) # Encode as JPEG - img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Decode back - img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True) - img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8) - - return img.numpy() - - -def crop_and_resize(image: tf.Tensor, crop_scale: float, batch_size: int) -> tf.Tensor: - """ - Center-crop an image and resize it back to original dimensions. - - Uses the same logic as in the training data pipeline for distribution matching. - - Args: - image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) with values in [0,1] - crop_scale: Area of center crop relative to original image - batch_size: Batch size - - Returns: - tf.Tensor: The cropped and resized image - """ - # Handle 3D inputs by adding batch dimension if needed - assert image.shape.ndims in (3, 4), "Image must be 3D or 4D tensor" - expanded_dims = False - if image.shape.ndims == 3: - image = tf.expand_dims(image, axis=0) - expanded_dims = True - - # Calculate crop dimensions (note: we use sqrt(crop_scale) for h/w) - new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,)) - new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,)) - - # Create bounding box for the crop - height_offsets = (1 - new_heights) / 2 - width_offsets = (1 - new_widths) / 2 - bounding_boxes = tf.stack( - [ - height_offsets, - width_offsets, - height_offsets + new_heights, - width_offsets + new_widths, - ], - axis=1, - ) - - # Apply crop and resize - image = tf.image.crop_and_resize( - image, bounding_boxes, tf.range(batch_size), (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE) - ) - - # Remove batch dimension if it was added - if expanded_dims: - image = image[0] - - return image - - -def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image: - """ - Center crop an image to match training data distribution. - - Args: - image: Input image (PIL or numpy array) - - Returns: - Image.Image: Cropped PIL Image - """ - batch_size = 1 - crop_scale = 0.9 - - # Convert to TF Tensor if needed - if not isinstance(image, tf.Tensor): - image = tf.convert_to_tensor(np.array(image)) - - orig_dtype = image.dtype - - # Convert to float32 in range [0,1] - image = tf.image.convert_image_dtype(image, tf.float32) - - # Apply center crop and resize - image = crop_and_resize(image, crop_scale, batch_size) - - # Convert back to original data type - image = tf.clip_by_value(image, 0, 1) - image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True) - - # Convert to PIL Image - return Image.fromarray(image.numpy()).convert("RGB") - - -def check_image_format(image: Any) -> None: - """ - Validate input image format. - - Args: - image: Image to check - - Raises: - AssertionError: If image format is invalid - """ - is_numpy_array = isinstance(image, np.ndarray) - has_correct_shape = len(image.shape) == 3 and image.shape[-1] == 3 - has_correct_dtype = image.dtype == np.uint8 - - assert is_numpy_array and has_correct_shape and has_correct_dtype, ( - "Incorrect image format detected! 
Make sure that the input image is a " - "numpy array with shape (H, W, 3) and dtype np.uint8!" - ) - - -def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray: - """ - Normalize proprioception data to match training distribution. - - Args: - proprio: Raw proprioception data - norm_stats: Normalization statistics - - Returns: - np.ndarray: Normalized proprioception data - """ - if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS: - mask = norm_stats.get("mask", np.ones_like(norm_stats["min"], dtype=bool)) - proprio_high, proprio_low = np.array(norm_stats["max"]), np.array(norm_stats["min"]) - elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99: - mask = norm_stats.get("mask", np.ones_like(norm_stats["q01"], dtype=bool)) - proprio_high, proprio_low = np.array(norm_stats["q99"]), np.array(norm_stats["q01"]) - else: - raise ValueError("Unsupported action/proprio normalization type detected!") - - normalized_proprio = np.clip( - np.where( - mask, - 2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1, - proprio, - ), - a_min=-1.0, - a_max=1.0, - ) - - return normalized_proprio - - -def prepare_images_for_vla(images: List[np.ndarray], cfg: Any) -> List[Image.Image]: - """ - Prepare images for VLA input by resizing and cropping as needed. - - Args: - images: List of input images as numpy arrays - cfg: Configuration object with parameters - - Returns: - List[Image.Image]: Processed images ready for the model - """ - processed_images = [] - - for image in images: - # Validate format - check_image_format(image) - - # Resize if needed - if image.shape != (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE, 3): - image = resize_image_for_policy(image, OPENVLA_IMAGE_SIZE) - - # Convert to PIL image - pil_image = Image.fromarray(image).convert("RGB") - - # Apply center crop if configured - if cfg.center_crop: - pil_image = center_crop_image(pil_image) - - processed_images.append(pil_image) - - return processed_images - - -def get_vla_action( - cfg: Any, - vla: torch.nn.Module, - processor: Any, - obs: Dict[str, Any], - task_label: str, - action_head: Optional[torch.nn.Module] = None, - proprio_projector: Optional[torch.nn.Module] = None, - noisy_action_projector: Optional[torch.nn.Module] = None, - use_film: bool = False, -) -> List[np.ndarray]: - """ - Generate action predictions with the VLA policy. 
- - Args: - cfg: Configuration object with parameters - vla: The VLA model - processor: Model processor for inputs - obs: Observation dictionary - task_label: Text description of the task - action_head: Optional action head for continuous actions - proprio_projector: Optional proprioception projector - noisy_action_projector: Optional noisy action projector for diffusion - use_film: Whether to use FiLM - - Returns: - List[np.ndarray]: Predicted actions - """ - with torch.inference_mode(): - - # Collect all input images - all_images = [obs["full_image"]] - if cfg.num_images_in_input > 1: - all_images.extend([obs[k] for k in obs.keys() if "wrist" in k]) - - # Process images - all_images = prepare_images_for_vla(all_images, cfg) - - # Extract primary image and additional images - primary_image = all_images.pop(0) - - # Build VLA prompt - prompt = f"In: What action should the robot take to {task_label.lower()}?\nOut:" - - # Process primary image - inputs = processor(prompt, primary_image).to(DEVICE, dtype=torch.bfloat16) - - # Process additional wrist images if any - if all_images: - all_wrist_inputs = [ - processor(prompt, image_wrist).to(DEVICE, dtype=torch.bfloat16) for image_wrist in all_images - ] - # Concatenate all images - primary_pixel_values = inputs["pixel_values"] - all_wrist_pixel_values = [wrist_inputs["pixel_values"] for wrist_inputs in all_wrist_inputs] - inputs["pixel_values"] = torch.cat([primary_pixel_values] + all_wrist_pixel_values, dim=1) - - # Process proprioception data if used - proprio = None - if cfg.use_proprio: - proprio = obs["state"] - proprio_norm_stats = vla.norm_stats[cfg.unnorm_key]["proprio"] - obs["state"] = normalize_proprio(proprio, proprio_norm_stats) - proprio = obs["state"] - - # Generate action - if action_head is None: - # Standard VLA output (single-image inputs, discrete actions) - action, _ = vla.predict_action(**inputs, unnorm_key=cfg.unnorm_key, do_sample=False) - else: - # Custom action head for continuous actions - action, _ = vla.predict_action( - **inputs, - unnorm_key=cfg.unnorm_key, - do_sample=False, - proprio=proprio, - proprio_projector=proprio_projector, - noisy_action_projector=noisy_action_projector, - action_head=action_head, - use_film=use_film, - ) - - # Return action chunk as list of actions - return [action[i] for i in range(len(action))] - - -def get_action_from_server( - observation: Dict[str, Any], server_endpoint: str = "http://0.0.0.0:8777/act" -) -> Dict[str, Any]: - """ - Get VLA action from remote inference server. 
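    Example call (a sketch; the observation keys shown are placeholders, since the server
    defines what it expects, and numpy arrays can be included thanks to the json_numpy patch
    applied at the top of this file):

        result = get_action_from_server(
            {"full_image": img, "instruction": "put the pen in the cup"},
            server_endpoint="http://0.0.0.0:8777/act",
        )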
- - Args: - observation: Observation data to send to server - server_endpoint: URL of the inference server - - Returns: - Dict[str, Any]: Action response from server - """ - response = requests.post( - server_endpoint, - json=observation, - ) - return response.json() diff --git a/capvector-oft/experiments/robot/robot_utils.py b/capvector-oft/experiments/robot/robot_utils.py deleted file mode 100644 index bb905d8428e943df820143d800740adc79f32602..0000000000000000000000000000000000000000 --- a/capvector-oft/experiments/robot/robot_utils.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Utils for evaluating robot policies in various environments.""" - -import os -import random -import time -from typing import Any, Dict, List, Optional, Union - -import numpy as np -import torch - -from experiments.robot.openvla_utils import ( - get_vla, - get_vla_action, -) - -# Initialize important constants -ACTION_DIM = 7 -DATE = time.strftime("%Y_%m_%d") -DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") -DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") - -# Configure NumPy print settings -np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) - -# Initialize system prompt for OpenVLA v0.1 -OPENVLA_V01_SYSTEM_PROMPT = ( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." -) - -# Model image size configuration -MODEL_IMAGE_SIZES = { - "openvla": 224, - # Add other models as needed -} - - -def set_seed_everywhere(seed: int) -> None: - """ - Set random seed for all random number generators for reproducibility. - - Args: - seed: The random seed to use - """ - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - os.environ["PYTHONHASHSEED"] = str(seed) - - -def get_model(cfg: Any, wrap_diffusion_policy_for_droid: bool = False) -> torch.nn.Module: - """ - Load and initialize model for evaluation based on configuration. - - Args: - cfg: Configuration object with model parameters - wrap_diffusion_policy_for_droid: Whether to wrap diffusion policy for DROID - - Returns: - torch.nn.Module: The loaded model - - Raises: - ValueError: If model family is not supported - """ - if cfg.model_family == "openvla": - model = get_vla(cfg) - else: - raise ValueError(f"Unsupported model family: {cfg.model_family}") - - print(f"Loaded model: {type(model)}") - return model - - -def get_image_resize_size(cfg: Any) -> Union[int, tuple]: - """ - Get image resize dimensions for a specific model. - - If returned value is an int, the resized image will be a square. - If returned value is a tuple, the resized image will be a rectangle. 
- - Args: - cfg: Configuration object with model parameters - - Returns: - Union[int, tuple]: Image resize dimensions - - Raises: - ValueError: If model family is not supported - """ - if cfg.model_family not in MODEL_IMAGE_SIZES: - raise ValueError(f"Unsupported model family: {cfg.model_family}") - - return MODEL_IMAGE_SIZES[cfg.model_family] - - -def get_action( - cfg: Any, - model: torch.nn.Module, - obs: Dict[str, Any], - task_label: str, - processor: Optional[Any] = None, - action_head: Optional[torch.nn.Module] = None, - proprio_projector: Optional[torch.nn.Module] = None, - noisy_action_projector: Optional[torch.nn.Module] = None, - use_film: bool = False, -) -> Union[List[np.ndarray], np.ndarray]: - """ - Query the model to get action predictions. - - Args: - cfg: Configuration object with model parameters - model: The loaded model - obs: Observation dictionary - task_label: Text description of the task - processor: Model processor for inputs - action_head: Optional action head for continuous actions - proprio_projector: Optional proprioception projector - noisy_action_projector: Optional noisy action projector for diffusion - use_film: Whether to use FiLM - - Returns: - Union[List[np.ndarray], np.ndarray]: Predicted actions - - Raises: - ValueError: If model family is not supported - """ - with torch.no_grad(): - if cfg.model_family == "openvla": - action = get_vla_action( - cfg=cfg, - vla=model, - processor=processor, - obs=obs, - task_label=task_label, - action_head=action_head, - proprio_projector=proprio_projector, - noisy_action_projector=noisy_action_projector, - use_film=use_film, - ) - else: - raise ValueError(f"Unsupported model family: {cfg.model_family}") - - return action - - -def normalize_gripper_action(action: np.ndarray, binarize: bool = True) -> np.ndarray: - """ - Normalize gripper action from [0,1] to [-1,+1] range. - - This is necessary for some environments because the dataset wrapper - standardizes gripper actions to [0,1]. Note that unlike the other action - dimensions, the gripper action is not normalized to [-1,+1] by default. - - Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1 - - Args: - action: Action array with gripper action in the last dimension - binarize: Whether to binarize gripper action to -1 or +1 - - Returns: - np.ndarray: Action array with normalized gripper action - """ - # Create a copy to avoid modifying the original - normalized_action = action.copy() - - # Normalize the last action dimension to [-1,+1] - orig_low, orig_high = 0.0, 1.0 - normalized_action[..., -1] = 2 * (normalized_action[..., -1] - orig_low) / (orig_high - orig_low) - 1 - - if binarize: - # Binarize to -1 or +1 - normalized_action[..., -1] = np.sign(normalized_action[..., -1]) - - return normalized_action - - -def invert_gripper_action(action: np.ndarray) -> np.ndarray: - """ - Flip the sign of the gripper action (last dimension of action vector). - - This is necessary for environments where -1 = open, +1 = close, since - the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open. 
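-
-     For example, an open-gripper command that is +1 after `normalize_gripper_action`
-     is flipped to -1, which such environments read as "open".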
- - Args: - action: Action array with gripper action in the last dimension - - Returns: - np.ndarray: Action array with inverted gripper action - """ - # Create a copy to avoid modifying the original - inverted_action = action.copy() - - # Invert the gripper action - inverted_action[..., -1] *= -1.0 - - return inverted_action diff --git a/capvector-oft/prismatic/__init__.py b/capvector-oft/prismatic/__init__.py deleted file mode 100644 index a710b2bf5a6f40b66ad23c8b9c10fcac23e08398..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .models import available_model_names, available_models, get_model_description, load diff --git a/capvector-oft/prismatic/conf/__init__.py b/capvector-oft/prismatic/conf/__init__.py deleted file mode 100644 index 25c29771f319a363aa2692202ece4e0472e5343f..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/conf/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .datasets import DatasetConfig, DatasetRegistry -from .models import ModelConfig, ModelRegistry -from .vla import VLAConfig, VLARegistry diff --git a/capvector-oft/prismatic/conf/datasets.py b/capvector-oft/prismatic/conf/datasets.py deleted file mode 100644 index a085882e911260e0facada164df0e0cee79c85b2..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/conf/datasets.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -datasets.py - -Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant -and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes: - - Dataset Variant (Identifier) --> e.g., "llava-v15" - - Align Stage Dataset Components (annotations, images) - - Finetune Stage Dataset Components (annotations, images) - - Dataset Root Directory (Path) -""" - -from dataclasses import dataclass -from enum import Enum, unique -from pathlib import Path -from typing import Tuple - -from draccus import ChoiceRegistry - - -@dataclass -class DatasetConfig(ChoiceRegistry): - # fmt: off - dataset_id: str # Unique ID that fully specifies a dataset variant - - # Dataset Components for each Stage in < align | finetune > - align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage - finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage - - dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root - # fmt: on - - -# [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models) -@dataclass -class LLaVa_V15_Config(DatasetConfig): - dataset_id: str = "llava-v15" - - align_stage_components: Tuple[Path, Path] = ( - Path("download/llava-laion-cc-sbu-558k/chat.json"), - Path("download/llava-laion-cc-sbu-558k/"), - ) - finetune_stage_components: Tuple[Path, Path] = ( - Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"), - Path("download/llava-v1.5-instruct/"), - ) - dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") - - -# [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training) -@dataclass -class LLaVa_Multimodal_Only_Config(DatasetConfig): - dataset_id: str = "llava-multimodal" - - align_stage_components: Tuple[Path, Path] = ( - Path("download/llava-laion-cc-sbu-558k/chat.json"), - Path("download/llava-laion-cc-sbu-558k/"), - ) - finetune_stage_components: Tuple[Path, Path] = ( - 
Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"), - Path("download/llava-v1.5-instruct/"), - ) - dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") - - -# LLaVa-v15 + LVIS-Instruct-4V -@dataclass -class LLaVa_LVIS4V_Config(DatasetConfig): - dataset_id: str = "llava-lvis4v" - - align_stage_components: Tuple[Path, Path] = ( - Path("download/llava-laion-cc-sbu-558k/chat.json"), - Path("download/llava-laion-cc-sbu-558k/"), - ) - finetune_stage_components: Tuple[Path, Path] = ( - Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"), - Path("download/llava-v1.5-instruct/"), - ) - dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") - - -# LLaVa-v15 + LRV-Instruct -@dataclass -class LLaVa_LRV_Config(DatasetConfig): - dataset_id: str = "llava-lrv" - - align_stage_components: Tuple[Path, Path] = ( - Path("download/llava-laion-cc-sbu-558k/chat.json"), - Path("download/llava-laion-cc-sbu-558k/"), - ) - finetune_stage_components: Tuple[Path, Path] = ( - Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"), - Path("download/llava-v1.5-instruct/"), - ) - dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") - - -# LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct -@dataclass -class LLaVa_LVIS4V_LRV_Config(DatasetConfig): - dataset_id: str = "llava-lvis4v-lrv" - - align_stage_components: Tuple[Path, Path] = ( - Path("download/llava-laion-cc-sbu-558k/chat.json"), - Path("download/llava-laion-cc-sbu-558k/"), - ) - finetune_stage_components: Tuple[Path, Path] = ( - Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"), - Path("download/llava-v1.5-instruct/"), - ) - dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") - - -# === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! === -@unique -class DatasetRegistry(Enum): - # === LLaVa v1.5 === - LLAVA_V15 = LLaVa_V15_Config - - LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config - - LLAVA_LVIS4V = LLaVa_LVIS4V_Config - LLAVA_LRV = LLaVa_LRV_Config - - LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config - - @property - def dataset_id(self) -> str: - return self.value.dataset_id - - -# Register Datasets in Choice Registry -for dataset_variant in DatasetRegistry: - DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value) diff --git a/capvector-oft/prismatic/conf/models.py b/capvector-oft/prismatic/conf/models.py deleted file mode 100644 index 7a3d742e47fb5a542eb8438be16e08768db90d35..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/conf/models.py +++ /dev/null @@ -1,584 +0,0 @@ -""" -models.py - -Draccus Dataclass Definition for a ModelConfig object, with various registered subclasses for each model family and -variant thereof. A given model variant configures the following attributes: - - Pretrained Visual Representation (e.g., OpenAI CLIP ViT-L/14) + Pretrained LLM Backbone (e.g., LLaMa-2 7B) - - VLM Configuration + Parameters (e.g., MLP Projector, Image Preprocessing, etc.) 
- - [Optional] Stage 1 (`align`) Optimization Hyperparameters - - Stage 2 (`finetune`) Optimization Hyperparameters -""" - -from dataclasses import dataclass -from enum import Enum, unique -from typing import Optional - -from draccus import ChoiceRegistry - - -@dataclass -class ModelConfig(ChoiceRegistry): - # fmt: off - model_id: str # Unique Model ID that fully specifies a given variant - arch_specifier: str # Architecture specifier string (e.g., "gelu-mlp") - - # Pretrained Backbones - vision_backbone_id: str # Pretrained Visual Featurizer (from TIMM) to load - llm_backbone_id: str # Pretrained LLM (from HF Transformers) to load - - # Backbone Parameters - image_resize_strategy: str # Resizing strategy in < crop | letterbox | corner-pad > - llm_max_length: int # Maximum context length for LLM (can be < than max!) - - # === Multi-Stage Optimization Hyperparameters === - # By default, we assume an AdamW optimizer with FSDP (Gradient Sharding or Full Sharding depending on stage) - - # Align Stage Optimization Parameters - align_epochs: int # Epochs to Run (in case `max_steps` is not specified) - align_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs) - align_global_batch_size: int # Global Batch Size (divided across processes) - align_per_device_batch_size: int # Per-Device Batch Size (per-process) - # => # of accumulation steps is auto-computed - - align_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay) - align_weight_decay: float # Weight Decay for AdamW Optimizer - align_max_grad_norm: float # Max Grad Norm (for global gradient clipping) - align_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay") - align_warmup_ratio: float # Fraction of total steps to warmup - - align_train_strategy: str # Align Train Strategy (default: "fsdp-shard-grad-op") - - # Finetune Stage Optimization Parameters - finetune_epochs: int # Epochs to Run (in case `max_steps` is not specified) - finetune_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs) - finetune_global_batch_size: int # Global Batch Size (divided across processes) - finetune_per_device_batch_size: int # Per-Device Batch Size (per-process) - # => # of accumulation steps is auto-computed - - finetune_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay) - finetune_weight_decay: float # Weight Decay for AdamW Optimizer - finetune_max_grad_norm: float # Max Grad Norm (for global gradient clipping) - finetune_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay") - finetune_warmup_ratio: float # Fraction of total steps to warmup - - finetune_train_strategy: str # Finetune Train Strategy (default: "fsdp-full-shard") - - # Enable Gradient/Activation Checkpointing (for the LLM Backbone) - enable_gradient_checkpointing: bool = True - - # Enable Traditional Mixed Precision Training via Torch Native AMP (`autocast`) - enable_mixed_precision_training: bool = True # Whether to enable mixed precision training - reduce_in_full_precision: bool = False # Whether to run gradient reduction in FP32 - - # fmt: on - - -# === LLaVa v1.5 Reproduction - Fully Specified Configurations === -@dataclass -class LLaVa_v15_Reproduction_7B(ModelConfig): - model_id: str = "reproduction-llava-v15+7b" - arch_specifier: str = "gelu-mlp" - - vision_backbone_id: str = "clip-vit-l-336px" - llm_backbone_id: str = "vicuna-v15-7b" - - image_resize_strategy: str = "letterbox" - llm_max_length: int = 2048 - - # Align Stage Optimization Parameters - 
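- # Illustrative arithmetic (assuming an 8-GPU run): accumulation steps = global / (GPUs * per-device)
- #   = 256 / (8 * 16) = 2 for the align stage below.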
align_epochs: int = 1 - align_max_steps: Optional[int] = None - align_global_batch_size: int = 256 - align_per_device_batch_size: int = 16 - - align_learning_rate: float = 1e-3 - align_weight_decay: float = 0.0 - align_max_grad_norm: float = 1.0 - align_lr_scheduler_type: str = "linear-warmup+cosine-decay" - align_warmup_ratio: float = 0.03 - - align_train_strategy: str = "fsdp-shard-grad-op" - - # Finetune Stage Optimization Parameters - finetune_epochs: int = 1 - finetune_max_steps: Optional[int] = None - finetune_global_batch_size: int = 128 - finetune_per_device_batch_size: int = 16 - - finetune_learning_rate: float = 2e-5 - finetune_weight_decay: float = 0.1 - finetune_max_grad_norm: float = 1.0 - finetune_lr_scheduler_type: str = "linear-warmup+cosine-decay" - finetune_warmup_ratio: float = 0.03 - - finetune_train_strategy: str = "fsdp-full-shard" - - -@dataclass -class LLaVa_v15_Reproduction_13B(LLaVa_v15_Reproduction_7B): - model_id: str = "reproduction-llava-v15+13b" - llm_backbone_id: str = "vicuna-v15-13b" - - -# === Section 4.1 :: Optimization Procedure === - - -# Section 4.1A :: 🚀 --> Necessity of Multi-Stage Training -@dataclass -class Exp_7B_One_Stage(LLaVa_v15_Reproduction_7B): - model_id: str = "one-stage+7b" - arch_specifier: str = "no-align+gelu-mlp" - - -@dataclass -class Exp_13B_One_Stage(LLaVa_v15_Reproduction_13B): - model_id: str = "one-stage+13b" - arch_specifier: str = "no-align+gelu-mlp" - - -# Section 4.1B :: 🛠️ --> Full Finetuning through Visual Backbones -# =>> Note :: Run with `--stage full-finetune` -@dataclass -class Exp_7B_Full_Finetune_Multi_Stage(LLaVa_v15_Reproduction_7B): - model_id: str = "full-ft-multi-stage+7b" - - -@dataclass -class Exp_7B_Full_Finetune_One_Stage(Exp_7B_One_Stage): - model_id: str = "full-ft-one-stage+7b" - - -# === Section 4.2 :: Image Processing and Visual Representations === - - -# Section 4.2A :: 📸 --> Choosing a Pretrained Representation -@dataclass -class Exp_7B_IN1K_ViT_L_p16_224px(Exp_7B_One_Stage): - model_id: str = "in1k-224px+7b" - vision_backbone_id: str = "in1k-vit-l" - - -@dataclass -class Exp_7B_DINOv2_ViT_L_p14_224px(Exp_7B_One_Stage): - model_id: str = "dinov2-224px+7b" - vision_backbone_id: str = "dinov2-vit-l" - - -@dataclass -class Exp_7B_CLIP_ViT_L_p14_224px(Exp_7B_One_Stage): - model_id: str = "clip-224px+7b" - vision_backbone_id: str = "clip-vit-l" - - -@dataclass -class Exp_7B_SigLIP_ViT_SO_p14_224px(Exp_7B_One_Stage): - model_id: str = "siglip-224px+7b" - vision_backbone_id: str = "siglip-vit-so400m" - - -# Section 4.2B :: 📐 --> Choosing an Image Preprocessing Strategy -@dataclass -class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop(Exp_7B_One_Stage): - model_id: str = "clip-336px-resize-crop+7b" - image_resize_strategy: str = "resize-crop" - - -@dataclass -class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage): - model_id: str = "clip-336px-resize-naive+7b" - image_resize_strategy: str = "resize-naive" - - -@dataclass -class Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox(Exp_7B_One_Stage): - model_id: str = "siglip-384px-letterbox+7b" - vision_backbone_id: str = "siglip-vit-so400m-384px" - image_resize_strategy: str = "letterbox" - - -@dataclass -class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop(Exp_7B_One_Stage): - model_id: str = "siglip-384px-resize-crop+7b" - vision_backbone_id: str = "siglip-vit-so400m-384px" - image_resize_strategy: str = "resize-crop" - - -@dataclass -class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive(Exp_7B_One_Stage): - model_id: str = "siglip-384px-resize-naive+7b" - 
vision_backbone_id: str = "siglip-vit-so400m-384px" - image_resize_strategy: str = "resize-naive" - - -# Section 4.2D :: 🥞 --> Stacking/Ensembling Visual Representations -@dataclass -class Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox(Exp_7B_One_Stage): - model_id: str = "dinoclip-336px-letterbox+7b" - vision_backbone_id: str = "dinoclip-vit-l-336px" - image_resize_strategy: str = "letterbox" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -@dataclass -class Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage): - model_id: str = "dinoclip-336px-resize-naive+7b" - vision_backbone_id: str = "dinoclip-vit-l-336px" - image_resize_strategy: str = "resize-naive" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -@dataclass -class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox(Exp_7B_One_Stage): - model_id: str = "dinosiglip-384px-letterbox+7b" - vision_backbone_id: str = "dinosiglip-vit-so-384px" - image_resize_strategy: str = "letterbox" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -@dataclass -class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive(Exp_7B_One_Stage): - model_id: str = "dinosiglip-384px-resize-naive+7b" - vision_backbone_id: str = "dinosiglip-vit-so-384px" - image_resize_strategy: str = "resize-naive" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -# === Section 4.3 :: Language Models === - - -# Section 4.3A :: 📝 --> Base vs. Instruct-Tuned (Chat) LLMs -@dataclass -class Exp_7B_Llama2(Exp_7B_One_Stage): - model_id: str = "llama2+7b" - llm_backbone_id: str = "llama2-7b-pure" - - -@dataclass -class Exp_13B_Llama2(Exp_13B_One_Stage): - model_id: str = "llama2+13b" - llm_backbone_id: str = "llama2-13b-pure" - - -# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~ -@dataclass -class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage): - model_id: str = "llama2-chat+7b" - llm_backbone_id: str = "llama2-7b-chat" - - -@dataclass -class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage): - model_id: str = "llama2-chat+13b" - llm_backbone_id: str = "llama2-13b-chat" - - -@dataclass -class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage): - model_id: str = "mistral-v0.1+7b" - llm_backbone_id: str = "mistral-v0.1-7b-pure" - - -@dataclass -class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage): - model_id: str = "mistral-instruct-v0.1+7b" - llm_backbone_id: str = "mistral-v0.1-7b-instruct" - - -@dataclass -class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage): - model_id: str = "phi-2+3b" - llm_backbone_id: str = "phi-2-3b" - - -# Section 4.3B :: ✌️ --> Co-training on Language-only Data -# =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training) -@dataclass -class Exp_7B_Vicuna_No_Cotraining(Exp_7B_One_Stage): - model_id: str = "vicuna-no-cotraining+7b" - - -@dataclass -class Exp_7B_Llama2_No_Cotraining(Exp_7B_One_Stage): - model_id: str = "llama2-no-cotraining+7b" - llm_backbone_id: str = "llama2-7b-pure" - - -# === Section 4.4 :: Scaling Properties - Train Time & Data === - - -# Section 4.4A :: ⏰ --> Scaling Train Time -@dataclass -class Exp_7B_1p25_Epochs(Exp_7B_One_Stage): - model_id: str = "train-1.25-epochs+7b" - finetune_max_steps: int = 6500 - - -@dataclass -class Exp_7B_1p5_Epochs(Exp_7B_One_Stage): - model_id: str = "train-1.5-epochs+7b" - finetune_max_steps: int = 7800 - - -@dataclass -class Exp_7B_2_Epochs(Exp_7B_One_Stage): - model_id: str = "train-2-epochs+7b" - finetune_epochs: int = 2 - - -@dataclass -class Exp_7B_3_Epochs(Exp_7B_One_Stage): - model_id: str = "train-3-epochs+7b" - finetune_epochs: int = 3 - - -# 
Section 4.4B :: 📚 --> Scaling Data -# =>> Note :: Run with `--dataset.type "llava-lvis4v"` -@dataclass -class Exp_7B_LLaVa_LVIS4V(Exp_7B_One_Stage): - model_id: str = "llava-lvis4v+7b" - - -# =>> Note :: Run with `--dataset.type "llava-lrv"` -@dataclass -class Exp_7B_LLaVa_LRV(Exp_7B_One_Stage): - model_id: str = "llava-lrv+7b" - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Exp_7B_LLaVa_LVIS4V_LRV(Exp_7B_One_Stage): - model_id: str = "llava-lvis4v-lrv+7b" - - -# === Section 5 :: Prisms === - - -# Prism-CLIP -@dataclass -class Prism_7B_CLIP_Controlled(Exp_7B_One_Stage): - model_id: str = "prism-clip-controlled+7b" - vision_backbone_id: str = "clip-vit-l-336px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - - -@dataclass -class Prism_13B_CLIP_Controlled(Exp_13B_One_Stage): - model_id: str = "prism-clip-controlled+13b" - vision_backbone_id: str = "clip-vit-l-336px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-13b-pure" - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Prism_7B_CLIP(Exp_7B_One_Stage): - model_id: str = "prism-clip+7b" - vision_backbone_id: str = "clip-vit-l-336px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - finetune_epochs: int = 2 - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Prism_13B_CLIP(Exp_13B_One_Stage): - model_id: str = "prism-clip+13b" - vision_backbone_id: str = "clip-vit-l-336px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-13b-pure" - finetune_epochs: int = 2 - - -# Prism-SigLIP -@dataclass -class Prism_7B_SigLIP_Controlled(Exp_7B_One_Stage): - model_id: str = "prism-siglip-controlled+7b" - vision_backbone_id: str = "siglip-vit-so400m-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - - -@dataclass -class Prism_13B_SigLIP_Controlled(Exp_13B_One_Stage): - model_id: str = "prism-siglip-controlled+13b" - vision_backbone_id: str = "siglip-vit-so400m-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-13b-pure" - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Prism_7B_SigLIP(Exp_7B_One_Stage): - model_id: str = "prism-siglip+7b" - vision_backbone_id: str = "siglip-vit-so400m-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - finetune_epochs: int = 2 - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Prism_13B_SigLIP(Exp_13B_One_Stage): - model_id: str = "prism-siglip+13b" - vision_backbone_id: str = "clip-vit-l-336px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-13b-pure" - finetune_epochs: int = 2 - - -# Prism-DINOSigLIP -@dataclass -class Prism_7B_DINOSigLIP_Controlled(Exp_7B_One_Stage): - model_id: str = "prism-dinosiglip-controlled+7b" - vision_backbone_id: str = "dinosiglip-vit-so-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -@dataclass -class Prism_13B_DINOSigLIP_Controlled(Exp_13B_One_Stage): - model_id: str = "prism-dinosiglip-controlled+13b" - vision_backbone_id: str = "dinosiglip-vit-so-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-13b-pure" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -# =>> Note :: Run with `--dataset.type 
"llava-lvis4v-lrv"` -@dataclass -class Prism_7B_DINOSigLIP(Exp_7B_One_Stage): - model_id: str = "prism-dinosiglip+7b" - vision_backbone_id: str = "dinosiglip-vit-so-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - arch_specifier: str = "no-align+fused-gelu-mlp" - finetune_epochs: int = 2 - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Prism_13B_DINOSigLIP(Exp_13B_One_Stage): - model_id: str = "prism-dinosiglip+13b" - vision_backbone_id: str = "dinosiglip-vit-so-384px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-13b-pure" - arch_specifier: str = "no-align+fused-gelu-mlp" - finetune_epochs: int = 2 - - -# [Inference-Optimized] 224px Prisms -@dataclass -class Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive(Exp_7B_One_Stage): - model_id: str = "dinosiglip-224px-resize-naive+7b" - vision_backbone_id: str = "dinosiglip-vit-so-224px" - image_resize_strategy: str = "resize-naive" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -@dataclass -class Prism_7B_DINOSigLIP_224px_Controlled(Exp_7B_One_Stage): - model_id: str = "prism-dinosiglip-224px-controlled+7b" - vision_backbone_id: str = "dinosiglip-vit-so-224px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - arch_specifier: str = "no-align+fused-gelu-mlp" - - -# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` -@dataclass -class Prism_7B_DINOSigLIP_224px(Exp_7B_One_Stage): - model_id: str = "prism-dinosiglip-224px+7b" - vision_backbone_id: str = "dinosiglip-vit-so-224px" - image_resize_strategy: str = "resize-naive" - llm_backbone_id: str = "llama2-7b-pure" - arch_specifier: str = "no-align+fused-gelu-mlp" - finetune_epochs: int = 2 - - -# === Define a Model Registry Enum for Reference & Validation === -@unique -class ModelRegistry(Enum): - # === LLaVa v1.5 Base Reproductions === - REPRODUCTION_7B = LLaVa_v15_Reproduction_7B - REPRODUCTION_13B = LLaVa_v15_Reproduction_13B - - # === Section 4.1 :: Optimization Procedure === - EXP_ONE_STAGE_7B = Exp_7B_One_Stage - EXP_ONE_STAGE_13B = Exp_13B_One_Stage - - EXP_FULL_FT_MULTI_STAGE = Exp_7B_Full_Finetune_Multi_Stage - EXP_FULL_FT_ONE_STAGE = Exp_7B_Full_Finetune_One_Stage - - # === Section 4.2 :: Image Processing and Visual Representations === - EXP_IN1K_224PX = Exp_7B_IN1K_ViT_L_p16_224px - EXP_DINOV2_224PX = Exp_7B_DINOv2_ViT_L_p14_224px - EXP_CLIP_224PX = Exp_7B_CLIP_ViT_L_p14_224px - EXP_SIGLIP_224PX = Exp_7B_SigLIP_ViT_SO_p14_224px - - EXP_CLIP_336PX_RESIZE_CROP = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop - EXP_CLIP_336PX_RESIZE_NAIVE = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive - EXP_SIGLIP_384PX_LETTERBOX = Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox - EXP_SIGLIP_384PX_RESIZE_CROP = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop - EXP_SIGLIP_384PX_RESIZE_NAIVE = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive - - EXP_DINOCLIP_336PX_LETTERBOX = Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox - EXP_DINOCLIP_336PX_RESIZE_NAIVE = Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive - EXP_DINOSIGLIP_384PX_LETTERBOX = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox - EXP_DINOSIGLIP_384PX_RESIZE_NAIVE = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive - - # === Section 4.3 :: Language Models === - EXP_LLAMA2_7B = Exp_7B_Llama2 - EXP_LLAMA2_13B = Exp_13B_Llama2 - - # ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~ - EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat - EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat - 
EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1 - EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1 - EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2 - - # Cotraining w/ Unimodal Data - EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining - EXP_LLAMA2_NO_COTRAINING_7B = Exp_7B_Llama2_No_Cotraining - - # === Section 4.4 :: Scaling Properties - Train Time & Data === - EXP_1P25_EPOCHS = Exp_7B_1p25_Epochs - EXP_1P5_EPOCHS = Exp_7B_1p5_Epochs - EXP_2_EPOCHS = Exp_7B_2_Epochs - EXP_3_EPOCHS = Exp_7B_3_Epochs - - EXP_LLAVA_LVIS4V = Exp_7B_LLaVa_LVIS4V - EXP_LLAVA_LRV = Exp_7B_LLaVa_LRV - EXP_LLAVA_LVIS4V_LRV = Exp_7B_LLaVa_LVIS4V_LRV - - # === Section 5 :: Prisms === - PRISM_CLIP_CONTROLLED_7B = Prism_7B_CLIP_Controlled - PRISM_CLIP_CONTROLLED_13B = Prism_13B_CLIP_Controlled - PRISM_CLIP_7B = Prism_7B_CLIP - PRISM_CLIP_13B = Prism_13B_CLIP - - PRISM_SIGLIP_CONTROLLED_7B = Prism_7B_SigLIP_Controlled - PRISM_SIGLIP_CONTROLLED_13B = Prism_13B_SigLIP_Controlled - PRISM_SIGLIP_7B = Prism_7B_SigLIP - PRISM_SIGLIP_13B = Prism_13B_SigLIP - - PRISM_DINOSIGLIP_CONTROLLED_7B = Prism_7B_DINOSigLIP_Controlled - PRISM_DINOSIGLIP_CONTROLLED_13B = Prism_13B_DINOSigLIP_Controlled - PRISM_DINOSIGLIP_7B = Prism_7B_DINOSigLIP - PRISM_DINOSIGLIP_13B = Prism_13B_DINOSigLIP - - # === Inference Optimized :: 224px Prisms === - OPT_DINOSIGLIP_224PX_RESIZE_NAIVE = Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive - PRISM_DINOSIGLIP_224PX_CONTROLLED_7B = Prism_7B_DINOSigLIP_224px_Controlled - PRISM_DINOSIGLIP_224PX_7B = Prism_7B_DINOSigLIP_224px - - @property - def model_id(self) -> str: - return self.value.model_id - - -# Register Models in Choice Registry -for model_variant in ModelRegistry: - ModelConfig.register_subclass(model_variant.model_id, model_variant.value) diff --git a/capvector-oft/prismatic/conf/vla.py b/capvector-oft/prismatic/conf/vla.py deleted file mode 100644 index 102694a5d58df3038e9914d18c2d42bbfa5d6383..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/conf/vla.py +++ /dev/null @@ -1,235 +0,0 @@ -""" -vla.py - -Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and -model configuration thereof. A given VLA model (`policy`) configures the following attributes: - - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.) 
- - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`) - - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning) - - Training / Optimization Hyperparameters -""" - -from dataclasses import dataclass -from enum import Enum, unique -from pathlib import Path -from typing import Optional, Union - -from draccus import ChoiceRegistry - - -@dataclass -class VLAConfig(ChoiceRegistry): - # fmt: off - vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant - base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`) - freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining) - freeze_llm_backbone: bool # Freeze LLM Backbone parameters - unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen) - - # Data Mixture Parameters - data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`) - shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE) - - # Optimization Parameters - epochs: int # Epochs to Run (in case `max_steps` is not specified) - max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`) - - expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware - global_batch_size: int # Global Batch Size (divided across processes / world size) - per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU) - # =>> # of accumulation steps is auto-computed - - learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay) - weight_decay: float # Weight Decay for AdamW Optimizer - max_grad_norm: float # Max Grad Norm (for global gradient clipping) - lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay") - warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers) - - train_strategy: str # Train Strategy (default "fsdp-full-shard") - - # Enable Gradient/Activation Checkpointing (for the LLM Backbone) - enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training - - # Mixed Precision Training via Torch Native AMP (`autocast`) - enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision - reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision - - # fmt: on - - -# === OpenVLA Training Configurations === - - -# = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge = -@dataclass -class Exp_SigLIP_224px_Bridge(VLAConfig): - vla_id: str = "siglip-224px+mx-bridge" - base_vlm: Union[str, Path] = "siglip-224px+7b" - - freeze_vision_backbone: bool = False - freeze_llm_backbone: bool = False - unfreeze_last_llm_layer: bool = False - - # Data Mixture Parameters - data_mix: str = "bridge" - shuffle_buffer_size: int = 256_000 - - # Optimization Parameters - epochs: int = 1000 - max_steps: Optional[int] = None - - expected_world_size: int = 8 - global_batch_size: int = 256 - per_device_batch_size: int = 32 - - learning_rate: float = 2e-5 - weight_decay: float = 0.0 - max_grad_norm: float = 1.0 - lr_scheduler_type: str = "constant" - warmup_ratio: float = 0.0 - - train_strategy: str = "fsdp-full-shard" - - -# = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge = -@dataclass -class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px-icy+mx-bridge" - base_vlm: Union[str, Path] = "siglip-224px+7b" - 
freeze_vision_backbone: bool = True - - -# = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge = -@dataclass -class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): - vla_id: str = "prism-dinosiglip-224px+mx-bridge" - base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" - - data_mix: str = "bridge" - - -# = [64 GPU] SigLIP 224px + OXE Magic Soup = -@dataclass -class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px+mx-oxe-magic-soup" - base_vlm: Union[str, Path] = "siglip-224px+7b" - - data_mix: str = "oxe_magic_soup" - - expected_world_size: int = 64 - global_batch_size: int = 2048 - per_device_batch_size: int = 32 - - -# = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ = -@dataclass -class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge): - vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus" - base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" - - # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling! - # data_mix: str = "oxe_magic_soup_plus" - data_mix: str = "oxe_magic_soup_plus_minus" - - expected_world_size: int = 64 - global_batch_size: int = 2048 - per_device_batch_size: int = 32 - - -# === OpenVLA Fine-tuning Configurations === - - -# = [8 GPU] SigLIP 224px + T-DROID = -@dataclass -class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl" - base_vlm: Union[str, Path] = "siglip-224px+7b" - - data_mix: str = "tdroid_carrot_in_bowl" - - -@dataclass -class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot" - base_vlm: Union[str, Path] = "siglip-224px+7b" - - data_mix: str = "tdroid_pour_corn_in_pot" - - -# = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning = -@dataclass -class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl" - base_vlm: Union[str, Path] = "siglip-224px+7b" - freeze_vision_backbone: bool = True - freeze_llm_backbone: bool = False - - data_mix: str = "tdroid_carrot_in_bowl" - - -@dataclass -class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl" - base_vlm: Union[str, Path] = "siglip-224px+7b" - freeze_vision_backbone: bool = True - freeze_llm_backbone: bool = True - unfreeze_last_llm_layer: bool = True - - data_mix: str = "tdroid_carrot_in_bowl" - - -@dataclass -class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl" - base_vlm: Union[str, Path] = "siglip-224px+7b" - freeze_vision_backbone: bool = False - freeze_llm_backbone: bool = True - unfreeze_last_llm_layer: bool = True - - data_mix: str = "tdroid_carrot_in_bowl" - - -# === [8 GPU] SigLIP 224px + FrankaWipe === -@dataclass -class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge): - vla_id: str = "siglip-224px+mx-droid_wipe" - base_vlm: Union[str, Path] = "siglip-224px+7b" - - data_mix: str = "droid_wipe" - - -# === Define a VLA Registry Enum for Reference & Validation === -@unique -class VLARegistry(Enum): - # Sanity Check Configurations =>> BridgeV2 - SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge - DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge - - # SigLIP Frozen Backbone Experiment - FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge - - # [OpenVLA v0.1 
7B] SigLIP 224px + OXE Magic Soup - SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup - - # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++ - DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus - - # === TDROID Fine-tuning Configs === - SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl - SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot - - SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl - SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl - SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl - - # === DROID Fine-tuning Configs === - SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe - - @property - def vla_id(self) -> str: - return self.value.vla_id - - -# Register VLAs in Choice Registry -for vla_variant in VLARegistry: - VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value) diff --git a/capvector-oft/prismatic/extern/__init__.py b/capvector-oft/prismatic/extern/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-oft/prismatic/extern/hf/__init__.py b/capvector-oft/prismatic/extern/hf/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-oft/prismatic/extern/hf/configuration_prismatic.py b/capvector-oft/prismatic/extern/hf/configuration_prismatic.py deleted file mode 100644 index 5bc8859ae37f52b8cb6f1a51326d12288ddcbcb0..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/extern/hf/configuration_prismatic.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -configuration_prismatic.py - -HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`. -Default configuration specifies `siglip-224px+7b`. 
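-
- For example (illustrative), `PrismaticConfig(vision_backbone_id="dinosiglip-vit-so-224px",
- llm_backbone_id="llama2-7b-pure")` selects the fused DINOv2 + SigLIP backbone and a LLaMa-2 7B
- LLM; both IDs must appear in the backbone mappings defined below, otherwise `__init__` raises a ValueError.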
-""" - -from typing import Any, Dict, List, Optional - -from transformers import PretrainedConfig -from transformers.models.auto import CONFIG_MAPPING - -# === Utilities for Mapping Prismatic names to HF names === -# fmt: off -VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = { - "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224], - - "clip-vit-l-336px": [336], - "siglip-vit-so400m-384px": [384], - - "dinoclip-vit-l-336px": [336, 336], - "dinosiglip-vit-so-224px": [224, 224], - "dinosiglip-vit-so-384px": [384, 384], -} -VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = { - "clip-vit-l": ["vit_large_patch14_clip_224.openai"], - "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"], - - "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"], - "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"], - - "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"], - "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"], - - "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"], - "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"], - "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"], -} -TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = { - "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"], - "dinov2-vit-l": [None], "in1k-vit-l": [None], - "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None], - "dinoclip-vit-l-336px": [None, "quick_gelu"], - "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None] -} - -LLM_BACKBONE_TO_HF_PATH = { - "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf", - "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", - - "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5", - - "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1", - "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", - - "phi-2-3b": "microsoft/phi-2", -} -LLM_BACKBONE_TO_HF_METACLASS = { - "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama", - "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", - - "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral", - - "phi-2-3b": "phi", -} - -VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys()) -VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH) -# fmt: on - - -class PrismaticConfig(PretrainedConfig): - model_type: str = "prismatic" - is_composition: bool = False - - def __init__( - self, - vision_backbone_id: str = "siglip-vit-so400m", - llm_backbone_id: str = "vicuna-v15-7b", - arch_specifier: str = "no-align+gelu-mlp", - use_fused_vision_backbone: Optional[bool] = None, - image_resize_strategy: str = "letterbox", - text_config: Optional[Dict[str, Any]] = None, - llm_max_length: int = 2048, - pad_token_id: int = 32000, - pad_to_multiple_of: int = 64, - output_projector_states: bool = False, - **kwargs: str, - ) -> None: - if vision_backbone_id not in VALID_VISION_BACKBONES: - raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }") - - if llm_backbone_id not in VALID_LLM_BACKBONES: - raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }") - - # Set Prismatic Configuration Fields - self.vision_backbone_id = 
vision_backbone_id - self.llm_backbone_id = llm_backbone_id - self.arch_specifier = arch_specifier - self.output_projector_states = output_projector_states - - # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing - self.use_fused_vision_backbone = ( - use_fused_vision_backbone - if use_fused_vision_backbone is not None - else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"]) - ) - - self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id] - self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id] - self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id] - self.image_resize_strategy = image_resize_strategy - - self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id] - self.llm_max_length = llm_max_length - self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of - - # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming! - self.text_config = ( - CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config) - if text_config is not None - else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]() - ) - - # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well... - super().__init__(pad_token_id=pad_token_id, **kwargs) - - -class OpenVLAConfig(PrismaticConfig): - model_type: str = "openvla" - - def __init__( - self, - norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None, - n_action_bins: int = 256, - **kwargs: str, - ) -> None: - self.norm_stats, self.n_action_bins = norm_stats, n_action_bins - - super().__init__(**kwargs) diff --git a/capvector-oft/prismatic/extern/hf/modeling_prismatic.py b/capvector-oft/prismatic/extern/hf/modeling_prismatic.py deleted file mode 100644 index 1ddeb671bbabdaf0f0f74b053f1cb8de9e02ef93..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/extern/hf/modeling_prismatic.py +++ /dev/null @@ -1,1085 +0,0 @@ -""" -modeling_prismatic.py - -Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions. -Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, -but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`. 
-""" - -import logging -from dataclasses import dataclass -from functools import partial -from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union - -import numpy as np -import timm -import tokenizers -import torch -import torch.nn as nn -import transformers -from timm.models.vision_transformer import LayerScale -from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel -from transformers.modeling_outputs import ModelOutput - -from prismatic.training.train_utils import ( - get_current_action_mask, - get_next_actions_mask, -) -from prismatic.vla.constants import ( - ACTION_DIM, - ACTION_PROPRIO_NORMALIZATION_TYPE, - ACTION_TOKEN_BEGIN_IDX, - IGNORE_INDEX, - NUM_ACTIONS_CHUNK, - STOP_INDEX, - NormalizationType, -) - -from .configuration_prismatic import OpenVLAConfig, PrismaticConfig - -# Set up logger -logger = logging.getLogger(__name__) - - -# === Utility Functions for Monkey-Patching === -def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: - def wrapper(*args: Any, **kwargs: Any) -> Any: - result = fn(*args, **kwargs) - return result[0] if isinstance(result, tuple) else result - - return wrapper - - -# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. -# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 -# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 -def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: - return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor - - -def ls_apply_patch(ls_module: LayerScale): - ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) - ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale) - del ls_module.gamma - - -# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) === -class PrismaticVisionBackbone(nn.Module): - """ - Vision backbone for Prismatic models that handles image feature extraction. - - Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations. - For fused backbones, features from both models are concatenated along the feature dimension. - """ - - def __init__( - self, - use_fused_vision_backbone: bool, - image_sizes: List[int], - timm_model_ids: List[str], - timm_override_act_layers: List[Optional[str]], - ) -> None: - """ - Initialize the vision backbone. 
- - Args: - use_fused_vision_backbone: Whether to use two backbones and fuse their features - image_sizes: List of image sizes for each backbone - timm_model_ids: List of TIMM model IDs to use for each backbone - timm_override_act_layers: List of activation layer overrides for each backbone - """ - super().__init__() - self.use_fused_vision_backbone = use_fused_vision_backbone - self.num_images_in_input = 1 # Default value, can be overridden later - - # Validate number of (fused) vision backbones - if len(timm_model_ids) > 2: - raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!") - - # Create primary featurizer - self.featurizer = self._create_featurizer( - model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0] - ) - self.embed_dim = self.featurizer.embed_dim - - # Create secondary featurizer if using fused backbone - if self.use_fused_vision_backbone: - self.fused_featurizer = self._create_featurizer( - model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1] - ) - self.embed_dim += self.fused_featurizer.embed_dim - - # Patch LayerScale modules for HF compatibility - self._patch_layer_scales() - - def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module: - """ - Create a TIMM-based featurizer model with appropriate configurations. - - Args: - model_id: The TIMM model ID to load - img_size: Input image size for the model - act_layer: Override for the activation layer type - - Returns: - A configured featurizer model - """ - featurizer = timm.create_model( - model_id, - pretrained=False, - num_classes=0, - img_size=img_size, - act_layer=act_layer, - ) - - # Monkey-patch the forward function to extract the second-to-last layer features - num_blocks = len(featurizer.blocks) - featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2})) - - return featurizer - - def _patch_layer_scales(self) -> None: - """ - Patch all LayerScale modules to be compatible with HF's parameter naming. - - HF Transformers overwrites parameters with names containing 'gamma', - so we need to rename and modify the forward method. - """ - # Patch primary featurizer - for module in self.featurizer.modules(): - if isinstance(module, LayerScale): - ls_apply_patch(module) - - # Patch secondary featurizer if it exists - if self.use_fused_vision_backbone: - for module in self.fused_featurizer.modules(): - if isinstance(module, LayerScale): - ls_apply_patch(module) - - def get_num_patches(self) -> int: - """ - Returns the number of vision patches output by the vision backbone. - - Returns: - Number of patches per image - """ - return self.featurizer.patch_embed.num_patches - - def get_num_images_in_input(self) -> int: - """ - Returns the number of input images for the vision backbone. - - Returns: - Number of images expected in the input - """ - return self.num_images_in_input - - def set_num_images_in_input(self, num_images_in_input: int) -> None: - """ - Sets the number of input images for the vision backbone. - - Args: - num_images_in_input: Number of images to expect in the input - """ - self.num_images_in_input = num_images_in_input - - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - """ - Implements the forward pass for the vision backbone. - - If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features - (otherwise uses SigLIP only). 
Allows multi-image inputs (but only for fused vision backbone). - - Args: - pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W). - """ - if self.num_images_in_input == 1: - if not self.use_fused_vision_backbone: - return self.featurizer(pixel_values) - - # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack - img, img_fused = torch.split(pixel_values, [3, 3], dim=1) - patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused) - - return torch.cat([patches, patches_fused], dim=2) - - else: - assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!" - - # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2) - images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1) - - # Process each image and collect patches - all_patches = [] - for img in images: - # Split each image further into two stacks of channels (each with 3 channels) - img_regular, img_fused = torch.split(img, [3, 3], dim=1) - - # Get patches from both SigLIP and DINOv2 vision transformers - patches = self.featurizer(img_regular) - patches_fused = self.fused_featurizer(img_fused) - - # Concatenate SigLIP and DINOv2 patches along the hidden dimension - combined_patches = torch.cat([patches, patches_fused], dim=2) - all_patches.append(combined_patches) - - # Concatenate all patches along the patch dimension - return torch.cat(all_patches, dim=1) - - -# === Prismatic Projector (nn.Module) Definitions === -class PrismaticProjector(nn.Module): - def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None: - super().__init__() - self.use_fused_vision_backbone = use_fused_vision_backbone - self.vision_dim, self.llm_dim = vision_dim, llm_dim - - # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors! 
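- # Single backbone  :: 2-layer GELU MLP, vision_dim -> llm_dim -> llm_dim
- # Fused backbone   :: 3-layer GELU MLP with a 4 * vision_dim intermediate projection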
- if not self.use_fused_vision_backbone: - self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True) - self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) - self.act_fn1 = nn.GELU() - else: - initial_projection_dim = 4 * vision_dim - self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True) - self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True) - self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) - self.act_fn1 = nn.GELU() - self.act_fn2 = nn.GELU() - - def forward(self, img_patches: torch.Tensor) -> torch.Tensor: - if not self.use_fused_vision_backbone: - projected_features = self.fc1(img_patches) - projected_features = self.act_fn1(projected_features) - projected_features = self.fc2(projected_features) - else: - projected_features = self.fc1(img_patches) - projected_features = self.act_fn1(projected_features) - projected_features = self.fc2(projected_features) - projected_features = self.act_fn2(projected_features) - projected_features = self.fc3(projected_features) - - return projected_features - - - # === Main HF Class Definitions === - @dataclass - class PrismaticCausalLMOutputWithPast(ModelOutput): - """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features.""" - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - # Additions for VLMs - projector_features: Optional[torch.FloatTensor] = None - - - class PrismaticPreTrainedModel(PreTrainedModel): - config_class: PretrainedConfig = PrismaticConfig - base_model_prefix: str = "model" - supports_gradient_checkpointing: bool = True - - _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"] - _skip_keys_device_placement: str = "past_key_values" - _supports_flash_attn_2: bool = True - - def _init_weights(self, module: nn.Module) -> None: - # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
- # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at - # https://github.com/TRI-ML/prismatic-vlms - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @property - def _supports_sdpa(self) -> bool: - """Check whether the LLM supports SDPA Attention""" - return self.language_model._supports_sdpa - - - class PrismaticForConditionalGeneration(PrismaticPreTrainedModel): - def __init__(self, config: PrismaticConfig) -> None: - super().__init__(config) - - # [Validation] Lightweight Validate on `config` Fields + Dependency Versions - if config.use_fused_vision_backbone is None: - raise ValueError("Missing config field `use_fused_vision_backbone`") - - if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}: - raise NotImplementedError( - "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue " - "if you urgently need support for latest TIMM versions." - ) - - if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"): - logger.warning( - f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got " - f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; " - f"there might be inference-time regressions due to dependency changes. If in doubt, please " - f"use the above versions."
- ) - - # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone) - self.vision_backbone = PrismaticVisionBackbone( - config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers - ) - - # Create Multimodal Projector - self.projector = PrismaticProjector( - config.use_fused_vision_backbone, - vision_dim=self.vision_backbone.embed_dim, - llm_dim=config.text_config.hidden_size, - ) - - # Instantiate LLM Backbone - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) - self.vocab_size = config.text_config.vocab_size - self.pad_token_id = config.pad_token_id - self.llm_dim = config.text_config.hidden_size - - # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing - self.post_init() - - # === `PreTrainedModel` Boilerplate === - def get_input_embeddings(self) -> nn.Module: - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value: nn.Module) -> None: - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings: nn.Module) -> None: - self.language_model.set_output_embeddings(new_embeddings) - - def get_decoder(self) -> nn.Module: - return self.language_model.get_decoder() - - def set_decoder(self, decoder: nn.Module) -> None: - self.language_model.set_decoder(decoder) - - def tie_weights(self) -> None: - self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op) - - def resize_token_embeddings( - self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None - ) -> nn.Embedding: - updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) - - # Update config/instance variables - self.config.text_config.vocab_size = updated_embeddings.num_embeddings - self.vocab_size = updated_embeddings.num_embeddings - - return updated_embeddings - - def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features): - """ - Replace embeddings in input_embeddings at positions where all_actions_mask is True - with embeddings from noisy_action_features, using vectorized operations. 
- - Args: - input_embeddings: Tensor of shape (B, S, D) - all_actions_mask: Boolean tensor of shape (B, S) - noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample - - Returns: - Modified input_embeddings tensor - """ - # Clone input to avoid modifying the original tensor - new_input_embeddings = input_embeddings.clone() - - # Create a tensor with the same shape of input_embeddings to hold the noisy action features - repositioned_noisy_action_features = torch.zeros_like(input_embeddings) - - # Create batch indices for splicing - batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device) - batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1]) - - # Get indices where mask is True for each sample - masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask]) - - # Move the noisy action features into their correct positions - repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features - - # Combine original input embeddings and noisy action embeddings using the mask - new_input_embeddings = torch.where( - all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings - ) - - return new_input_embeddings - - def _process_action_masks(self, labels): - """Helper to get action masks from labels""" - current_action_mask = get_current_action_mask(labels) - next_actions_mask = get_next_actions_mask(labels) - all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len) - return all_actions_mask - - def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False): - """Process vision features with optional FiLM conditioning""" - if use_film: - # FiLM: Infuse language inputs into visual features - patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D) - else: - patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D) - - # Project patch embeddings into language embedding space - return self.projector(patch_features) - - def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector): - """Process proprioceptive features and append to vision features""" - if proprio_projector is not None and proprio is not None: - # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim) - # proprio: (bsz, proprio_dim) or (propro_dim,) - proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim) - proprio_features = proprio_projector(proprio) # (bsz, llm_dim) - proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim) - # For simplicity, just append proprio token to the end of projected vision patch tokens - return torch.cat((projected_patch_embeddings, proprio_features), dim=1) - return projected_patch_embeddings - - def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask): - """Build multimodal embeddings and attention mask""" - # Update attention mask - projected_patch_attention_mask = None - if attention_mask is not None: - projected_patch_attention_mask = torch.full( - (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), - fill_value=True, - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Build multimodal embeddings & attention mask; insert embeddings after token (1:) - multimodal_embeddings = torch.cat( - [input_embeddings[:, :1, :], 
projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1 - ) - - multimodal_attention_mask = None - if attention_mask is not None: - multimodal_attention_mask = torch.cat( - [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1 - ) - - return multimodal_embeddings, multimodal_attention_mask - - def _build_multimodal_labels(self, labels, projected_patch_embeddings): - """Build multimodal labels with IGNORE_INDEX for patch embeddings""" - if labels is not None: - projected_patch_labels = torch.full( - (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), - fill_value=IGNORE_INDEX, - dtype=labels.dtype, - device=labels.device, - ) - return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1) - return None - - # === Core Prismatic VLM `forward()` Logic === - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_projector_features: Optional[bool] = None, - return_dict: Optional[bool] = None, - proprio=None, - proprio_projector=None, - noisy_actions=None, - noisy_action_projector=None, - diffusion_timestep_embeddings=None, - use_film: bool = False, - ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]: - """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - output_projector_features = output_projector_features if output_projector_features is not None else False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off) - use_cache = use_cache and not self.training - - # Instantiate Placeholder for Projector Features - projected_patch_embeddings = None - - # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` === - if input_ids.shape[1] == 1: - assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!" - assert past_key_values is not None, "You must provide `past_key_values` during cached generation!" - assert labels is None, "Unexpected key `labels` provided during cached generation!" - - language_model_output = self.language_model( - input_ids=input_ids, - attention_mask=None, - position_ids=None, - past_key_values=past_key_values, - inputs_embeds=None, - labels=None, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # === Handle Unimodal Forward === - elif pixel_values is None: - assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!" - assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!" 
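The `_build_multimodal_attention` and `_build_multimodal_labels` helpers above splice the projected patch embeddings in directly after the BOS position and mask those positions out of the loss. A minimal standalone sketch of that ordering follows; the toy sizes and the `-100` ignore index are illustrative assumptions, not values read from this file:

```python
import torch

B, S, P, D = 2, 10, 256, 64          # toy batch, prompt length, patch count, embed dim (assumed)
IGNORE_INDEX = -100                  # assumed loss-ignore value, as in standard HF label masking

text_embeds = torch.randn(B, S, D)   # embeddings of the tokenized prompt (BOS first)
patch_embeds = torch.randn(B, P, D)  # projector output for the image patches
attention_mask = torch.ones(B, S, dtype=torch.long)
labels = torch.full((B, S), IGNORE_INDEX, dtype=torch.long)

# Insert patch embeddings right after the BOS token (position 0), mirroring the helpers above
multimodal_embeds = torch.cat([text_embeds[:, :1], patch_embeds, text_embeds[:, 1:]], dim=1)

# Attention mask gets all-ones for the patch positions; labels get IGNORE_INDEX so they carry no loss
patch_mask = torch.ones(B, P, dtype=attention_mask.dtype)
multimodal_mask = torch.cat([attention_mask[:, :1], patch_mask, attention_mask[:, 1:]], dim=1)
patch_labels = torch.full((B, P), IGNORE_INDEX, dtype=labels.dtype)
multimodal_labels = torch.cat([labels[:, :1], patch_labels, labels[:, 1:]], dim=1)

assert multimodal_embeds.shape == (B, S + P, D)
assert multimodal_mask.shape == multimodal_labels.shape == (B, S + P)
```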
- - language_model_output = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # === Handle Multimodal Forward === - elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]): - assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!" - - # Get input embeddings (from language model embeddings) - input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D) - - # Extract action masks - all_actions_mask = self._process_action_masks(labels) - - # Extract the language portion of the input embeddings (i.e. remove the action tokens portion) - language_embeddings = input_embeddings[~all_actions_mask].reshape( - input_embeddings.shape[0], -1, input_embeddings.shape[2] - ) # (B, lang_seq_len, llm_dim) - - # Get visual features - projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film) - - # Add proprioceptive state if provided - projected_patch_embeddings = self._process_proprio_features( - projected_patch_embeddings, proprio, proprio_projector - ) - - # [Diffusion] Add diffusion timestep embedding if provided - if diffusion_timestep_embeddings is not None: - # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens - projected_patch_embeddings = torch.cat( - (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1 - ) - - # Process action embeddings - if noisy_actions is not None: - # Get mask corresponding to all action tokens - all_actions_mask = self._process_action_masks(labels) - - # Reshape noisy actions into individual action tokens - # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1) - B = noisy_actions.shape[0] - noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1) - - # Project noisy action tokens into language model embedding space - noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim) - - # Replace embeddings of the action tokens with noisy action embeddings - input_embeddings = self._replace_input_embeddings( - input_embeddings, all_actions_mask, noisy_action_features - ) - else: - # Replace the embeddings of the action tokens with zeros - # (Later on, the positional embeddings will be added to them) - all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1) - input_embeddings = input_embeddings * ~all_actions_mask - - # Build multimodal embeddings & attention mask - multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( - input_embeddings, projected_patch_embeddings, attention_mask - ) - - # Build labels for multimodal sequence if needed - multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings) - - # Dispatch to language model - language_model_output = self.language_model( - input_ids=None, - attention_mask=multimodal_attention_mask, - position_ids=None, - past_key_values=None, - inputs_embeds=multimodal_embeddings, - labels=multimodal_labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # === Otherwise =>> Assume Invalid! 
=== - elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]): - raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!") - - else: - raise ValueError( - "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n" - f"=> `input_ids` = {input_ids is not None}\n" - f"=> `attention_mask` = {attention_mask is not None}\n" - f"=> `pixel_values` = {pixel_values is not None}\n" - f"=> `labels` = {labels is not None}\n" - f"=> `input_embeds` = {inputs_embeds is not None}\n" - f"=> `past_key_values` = {past_key_values is not None}\n" - f"=> `use_cache` = {use_cache}" - ) - - # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`) - if not return_dict: - if output_projector_features and (projected_patch_embeddings is not None): - return *language_model_output, projected_patch_embeddings - - return language_model_output - - return PrismaticCausalLMOutputWithPast( - loss=language_model_output.loss, - logits=language_model_output.logits, - past_key_values=language_model_output.past_key_values, - hidden_states=language_model_output.hidden_states, - attentions=language_model_output.attentions, - projector_features=projected_patch_embeddings, - ) - - # === GenerationMixin Methods === - def prepare_inputs_for_generation( - self, - input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs: str, - ) -> Dict[str, torch.Tensor]: - """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.""" - if ((input_ids is not None) and (input_ids.shape[0] > 1)) or ( - (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1) - ): - raise ValueError("Generation with batch size > 1 is not currently supported!") - - # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - # If `input_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"input_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - # Make sure `pixel_values` are preserved in `model_inputs` - model_inputs.update( - { - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - } - ) - - return model_inputs - - # Defer to Language Model (all handle this differently, with different return types) - def _reorder_cache(self, *args, **kwargs) -> Any: - return self.language_model._reorder_cache(*args, **kwargs) - - -class OpenVLAForActionPrediction(PrismaticForConditionalGeneration): - config_class: PretrainedConfig = OpenVLAConfig - - def __init__(self, config: OpenVLAConfig) -> None: - super().__init__(config) - self.norm_stats = config.norm_stats - - # Compute action bins - self.bins = np.linspace(-1, 1, config.n_action_bins) - self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0 - - # Compute vocab size for de-tokenization -- revert added "multiple of" - self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of - - def _prepare_input_for_action_prediction(self, input_ids, attention_mask): - 
"""Prepares input for action prediction by adding necessary tokens""" - # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens - placeholder_action_token_ids = ( - torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype) - ) - input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1) - - # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time) - stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX - input_ids = torch.cat([input_ids, stop_token_id], dim=-1) - - # Extend the attention mask to fit the new shape of input - # Note: Only batch size == 1 supported right now - mask_extension = ( - torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1])) - .to(attention_mask.device) - .to(attention_mask.dtype) - ) - attention_mask = torch.cat([attention_mask, mask_extension], dim=-1) - - return input_ids, attention_mask - - def _prepare_labels_for_action_prediction(self, labels, input_ids): - """Creates labels tensor for action prediction if not provided""" - # Extend labels tensor with fake action labels - ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1 - labels_extension = ( - torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype) - * ARBITRARY_ACTION_TOKEN_IDX - ) - labels = torch.cat([labels, labels_extension], dim=-1) - - # Replace last label token with stop token - labels[:, -1] = STOP_INDEX - - return labels - - def _unnormalize_actions(self, normalized_actions, unnorm_key=None): - """Unnormalize actions using dataset statistics""" - action_norm_stats = self.get_action_stats(unnorm_key) - - if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS: - mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool)) - action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"]) - elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99: - mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) - action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) - else: - raise ValueError("Unsupported action/proprio normalization type detected!") - - actions = np.where( - mask, - 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low, - normalized_actions, - ) - - return actions - - def _run_diffusion_prediction( - self, - input_embeddings, - all_actions_mask, - noise, - action_head, - projected_patch_embeddings, - labels, - attention_mask, - NUM_PATCHES, - NUM_PROMPT_TOKENS, - noisy_action_projector, - ): - """Run diffusion-based action prediction""" - # Clone embedding for reuse in each timestep - orig_projected_patch_embeddings = projected_patch_embeddings.clone() - curr_noisy_actions = noise - - # Reverse diffusion: Iteratively denoise to generate action prediction - for t in action_head.noise_scheduler.timesteps: - # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action - # embedding, and diffusion timestep embedding) - timesteps = torch.Tensor([t]).to(labels.device) - diffusion_timestep_embeddings = ( - action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device) - ) # (B, llm_dim) - diffusion_timestep_embeddings = 
diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) - - # [Diffusion] Replace the embeddings of the action tokens with noisy actions - # (Later on, the positional embeddings will be added to them) - - # For simplicity, append diffusion timestep embedding to the end of projected vision tokens - projected_patch_embeddings = torch.cat( - (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1 - ) - - # Reshape and project noisy actions into language embedding space - B = curr_noisy_actions.shape[0] - orig_curr_noisy_actions_shape = curr_noisy_actions.shape - curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1) - noisy_action_features = noisy_action_projector(curr_noisy_actions) - curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape) - - # Replace action token embeddings with noisy action embeddings - input_embeddings = self._replace_input_embeddings( - input_embeddings.clone(), all_actions_mask, noisy_action_features - ) - - # Build multimodal embeddings and attention mask - multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( - input_embeddings, projected_patch_embeddings, attention_mask - ) - - # Forward pass through language model - language_model_output = self.language_model( - input_ids=None, - attention_mask=multimodal_attention_mask, - position_ids=None, - past_key_values=None, - inputs_embeds=multimodal_embeddings, - labels=None, - use_cache=None, - output_attentions=False, - output_hidden_states=True, - return_dict=True, - ) - - # Extract hidden states for action portion of response - last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D) - actions_hidden_states = last_hidden_states[ - :, - NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, - :, - ] # (B, act_chunk_len, D) - - # Predict noise and update noisy actions: x_t -> x_{t-1} - noise_pred = action_head.predict_noise(actions_hidden_states) - curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample - - curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) - - # Return final actions - return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states - - def _regression_or_discrete_prediction( - self, - input_embeddings, - all_actions_mask, - projected_patch_embeddings, - attention_mask, - labels, - NUM_PATCHES, - NUM_PROMPT_TOKENS, - action_head=None, - ): - """Run L1 regression-based continuous action prediction or discrete action tokens prediction.""" - # Zero out action token embeddings - all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1) - input_embeddings = input_embeddings * ~all_actions_mask - - # Build multimodal embeddings and attention mask - multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( - input_embeddings, projected_patch_embeddings, attention_mask - ) - - # Forward pass through language model - language_model_output = self.language_model( - input_ids=None, - attention_mask=multimodal_attention_mask, - position_ids=None, - past_key_values=None, - inputs_embeds=multimodal_embeddings, - labels=None, - use_cache=None, - output_attentions=False, - output_hidden_states=True, - return_dict=True, - ) - - # Extract hidden states for action tokens - last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D) - actions_hidden_states = last_hidden_states[ - :, - NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES 
+ NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, - :, - ] # (B, act_chunk_len, D) - - # Handle different prediction methods - if action_head is not None: - # L1 regression prediction - normalized_actions = action_head.predict_action(actions_hidden_states) - normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) - normalized_actions = normalized_actions.float().cpu().detach().numpy() - else: - # Discrete token-based prediction - predicted_action_token_ids = ( - language_model_output.logits[ - :, - NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, - ] - .argmax(dim=2) - .cpu() - .numpy() - ) - discretized_actions = self.vocab_size - predicted_action_token_ids - discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) - normalized_actions = self.bin_centers[discretized_actions] - normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) - - return normalized_actions, actions_hidden_states - - def predict_action( - self, - input_ids: Optional[torch.LongTensor] = None, - unnorm_key: Optional[str] = None, - proprio=None, - proprio_projector=None, - action_head=None, - noisy_action_projector=None, - use_film: bool = False, - **kwargs: str, - ) -> np.ndarray: - """Predict actions from input sequence, with options for different prediction methods. - - Args: - input_ids: Input token ids - unnorm_key: Key for unnormalization statistics - proprio: Proprioceptive features - proprio_projector: Projector for proprioceptive features - action_head: Optional head for L1 regression or diffusion-based prediction - noisy_action_projector: Projector for noisy actions in diffusion-based prediction - use_film: Whether to use FiLM conditioning - **kwargs: Additional arguments including pixel_values and attention_mask - - Returns: - Tuple of (unnormalized_actions, action_hidden_states) - """ - # If the special empty token ('') does not already appear after the colon (':') token in the prompt - # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time - if not torch.all(input_ids[:, -1] == 29871): - input_ids = torch.cat( - (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 - ) - - pixel_values = kwargs["pixel_values"] - attention_mask = kwargs["attention_mask"] - - # Create fake labels tensor (needed for action mask) - labels = input_ids.clone() - labels[:] = IGNORE_INDEX - - # Get number of tokens in prompt (excluding the start token) - NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token - - # Prepare inputs by adding necessary tokens - input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask) - - # Update labels tensor for action mask computation later - labels = self._prepare_labels_for_action_prediction(labels, input_ids) - - # Get input embeddings and action masks - input_embeddings = self.get_input_embeddings()(input_ids) - all_actions_mask = self._process_action_masks(labels) - - # Extract language embeddings - language_embeddings = input_embeddings[~all_actions_mask].reshape( - input_embeddings.shape[0], -1, input_embeddings.shape[2] - ) - - # Process vision features - projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film) - - # Add proprioceptive features if provided - use_proprio = proprio_projector is not None and proprio is not None - if use_proprio: - proprio = 
torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype) - projected_patch_embeddings = self._process_proprio_features( - projected_patch_embeddings, proprio, proprio_projector - ) - - # Use diffusion if provided, otherwise use regression or discrete prediction - use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler") - - # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present) - NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input() - if use_proprio: - NUM_PATCHES += 1 - if use_diffusion: - NUM_PATCHES += 1 - - if use_diffusion: - # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion - noise = torch.randn( - size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype - ) - - # Run diffusion-based prediction - normalized_actions, actions_hidden_states = self._run_diffusion_prediction( - input_embeddings, - all_actions_mask, - noise, - action_head, - projected_patch_embeddings, - labels, - attention_mask, - NUM_PATCHES, - NUM_PROMPT_TOKENS, - noisy_action_projector, - ) - else: - # Run regression or discrete token-based prediction - normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction( - input_embeddings, - all_actions_mask, - projected_patch_embeddings, - attention_mask, - labels, - NUM_PATCHES, - NUM_PROMPT_TOKENS, - action_head, - ) - - # Unnormalize predicted actions - actions = self._unnormalize_actions(normalized_actions, unnorm_key) - - return actions, actions_hidden_states - - @staticmethod - def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str: - """Validate and resolve the unnormalization key for action statistics""" - if unnorm_key is None: - assert len(norm_stats) == 1, ( - f"Your model was trained on more than one dataset, " - f"please pass a `unnorm_key` from the following options to choose the statistics " - f"used for un-normalizing actions: {norm_stats.keys()}" - ) - unnorm_key = next(iter(norm_stats.keys())) - - assert unnorm_key in norm_stats, ( - f"The `unnorm_key` you chose is not in the set of available dataset statistics, " - f"please choose from: {norm_stats.keys()}" - ) - return unnorm_key - - def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: - """Get the dimensionality of the policy's action space.""" - unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) - return len(self.norm_stats[unnorm_key]["action"]["min"]) - - def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]: - """Get all the logged statistics for the given dataset.""" - unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) - return self.norm_stats[unnorm_key]["action"] diff --git a/capvector-oft/prismatic/extern/hf/processing_prismatic.py b/capvector-oft/prismatic/extern/hf/processing_prismatic.py deleted file mode 100644 index f16e2a5466100ad23628e781589e8ae292cdf0b0..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/extern/hf/processing_prismatic.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -processing_prismatic.py - -HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration -specifies `siglip-224px+7b`. 
-""" - -from typing import Any, ClassVar, List, Optional, Tuple, Union - -import timm.data -import torch -import torchvision.transforms.functional as TVF -from PIL import Image -from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor -from transformers import PreTrainedTokenizerBase -from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin -from transformers.processing_utils import ProcessorMixin -from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from transformers.utils import TensorType - - -# === Image Processing === -def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image: - """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" - (w, h), max_wh = image.size, max(image.size) - horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2) - padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) - - return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant") - - -class PrismaticImageProcessor(ImageProcessingMixin): - model_input_names: ClassVar[List[str]] = ["pixel_values"] - - def __init__( - self, - use_fused_vision_backbone: bool = False, - image_resize_strategy: str = "letterbox", - input_sizes: Optional[List[Tuple[int, int, int]]] = None, - interpolations: Optional[List[str]] = None, - means: Optional[List[Tuple[float, float, float]]] = None, - stds: Optional[List[Tuple[float, float, float]]] = None, - **kwargs: str, - ) -> None: - """ - Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be - created by TIMM, and edited to follow our custom `image_resize_strategy` logic. - @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone - @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox > - @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height) - @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic") - @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`) - @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`) - """ - self.use_fused_vision_backbone = use_fused_vision_backbone - self.image_resize_strategy = image_resize_strategy - - # Handle `None` default values - input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes - means = [(0.5, 0.5, 0.5)] if means is None else means - stds = [(0.5, 0.5, 0.5)] if stds is None else stds - - # TIMM `data_cfg` Parameters - self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds - - # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values! 
- self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], [] - self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None - - for idx in range(len(input_sizes)): - transform = timm.data.create_transform( - input_size=self.input_sizes[idx], - interpolation=self.interpolations[idx], - mean=self.means[idx], - std=self.stds[idx], - crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`) - crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0` - is_training=False, # No image augmentations when loading the transform! - ) - - # [Validation] Ensure appropriate transform structure, expected sizes - if not ( - isinstance(transform, Compose) - and (len(transform.transforms) == 4) - and isinstance(transform.transforms[0], Resize) - and isinstance(transform.transforms[1], CenterCrop) - and isinstance(transform.transforms[2], ToTensor) - and isinstance(transform.transforms[3], Normalize) - and (transform.transforms[0].size == self.input_sizes[idx][-1]) - and (transform.transforms[1].size == self.input_sizes[idx][-2:]) - ): - raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`") - - # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute. - # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`) - resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3] - self.tvf_resize_params.append( - { - "size": resize_t.size, - "interpolation": TVF.pil_modes_mapping[resize_t.interpolation], - "max_size": None, - "antialias": True, - } - ) - self.tvf_crop_params.append({"output_size": crop_t.size}) - self.tvf_normalize_params.append( - { - "mean": norm_t.mean.float().numpy().tolist(), - "std": norm_t.std.float().numpy().tolist(), - "inplace": False, - } - ) - self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None - - # Handle Prismatic `image_resize_strategy` - if self.image_resize_strategy == "resize-naive": - self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size) - elif self.image_resize_strategy == "letterbox": - self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]]) - elif self.image_resize_strategy == "resize-crop": - pass - else: - raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!") - - # Dispatch **kwargs to super() - super().__init__(**kwargs) - - def apply_transform(self, img: Image.Image) -> torch.Tensor: - """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])""" - if self.tvf_do_letterbox: - img = letterbox_pad_transform(img, self.tvf_letterbox_fill) - - # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side! 
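The "channel-stacked" contract mentioned in the comment above simply concatenates the per-backbone image tensors along the channel dimension, so a fused (dual) vision backbone receives a 6-channel input that can be split apart again on the model side. A toy sketch, assuming two 3x224x224 inputs purely for illustration:

```python
import torch

# Two per-backbone image tensors in a fused setup -- sizes assumed for illustration
img_a = torch.randn(3, 224, 224)
img_b = torch.randn(3, 224, 224)

stacked = torch.vstack([img_a, img_b])   # channel-stacked input handed to the model
assert stacked.shape == (6, 224, 224)

# The model side can recover the per-backbone views by splitting the channels again
img_a_rec, img_b_rec = torch.split(stacked, 3, dim=0)
assert torch.equal(img_a_rec, img_a) and torch.equal(img_b_rec, img_b)
```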
- imgs_t = [] - for idx in range(len(self.input_sizes)): - img_idx = TVF.resize(img, **self.tvf_resize_params[idx]) - img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx]) - img_idx_t = TVF.to_tensor(img_idx) - img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx]) - imgs_t.append(img_idx_t) - - # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0 - img_t = torch.vstack(imgs_t) - - return img_t - - def preprocess( - self, - images: Union[Image.Image, List[Image.Image]], - return_tensors: Optional[Union[str, TensorType]] = None, - **_: str, - ) -> BatchFeature: - """ - Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we - explicitly only handle PIL.Image.Image instances for simplicity. - @param images: A (batch of) PIL.Image.Image instance(s) to preprocess. - @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray - @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values" - """ - if not isinstance(images, list): - images = [images] - - # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor - pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images]) - - # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert - return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors) - - def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature: - return self.preprocess(images, **kwargs) - - -# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer === -# =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py -class PrismaticProcessor(ProcessorMixin): - attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"] - image_processor_class: str = "AutoImageProcessor" - tokenizer_class: str = "AutoTokenizer" - - def __init__( - self, - image_processor: Optional[ImageProcessingMixin] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - ) -> None: - super().__init__(image_processor, tokenizer) - - def __call__( - self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], - images: Union[Image.Image, List[Image.Image]], - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Optional[Union[bool, str, TruncationStrategy]] = None, - max_length: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - ) -> BatchFeature: - """ - Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer, - forwards images to PrismaticImageProcessor. - @param text: The (batch) of text to encode; must be a string or list of strings. - @param images: A (batch of) PIL.Image.Image instance(s) to preprocess. - @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False > - @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified - @param max_length: Maximum length (in tokens) to truncate - @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH) - @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`. 
- """ - pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] - text_inputs = self.tokenizer( - text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length - ) - - # [Validate] Need same number of images and text inputs! - if pixel_values.shape[0] != text_inputs.input_ids.shape[0]: - raise ValueError("Batch is malformed; expected same number of images and text inputs!") - - return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) - - # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation === - def batch_decode( - self, - sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs: str, - ) -> List[str]: - return self.tokenizer.batch_decode( - sequences=sequences, - skip_special_tokens=skip_special_tokens, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - - def decode( - self, - token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs: str, - ) -> str: - return self.tokenizer.decode( - token_ids=token_ids, - skip_special_tokens=skip_special_tokens, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - - @property - def model_input_names(self) -> List[str]: - tokenizer_input_names = self.tokenizer.model_input_names - image_processor_input_names = self.image_processor.model_input_names - - return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/capvector-oft/prismatic/models/__init__.py b/capvector-oft/prismatic/models/__init__.py deleted file mode 100644 index 85a3ebb94a024811e1e567cab3d80d805f9a48f5..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .load import available_model_names, available_models, get_model_description, load, load_vla -from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm diff --git a/capvector-oft/prismatic/models/action_heads.py b/capvector-oft/prismatic/models/action_heads.py deleted file mode 100644 index 2ee1d442a90c2c70538c8ae9860f6c4a938260db..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/action_heads.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Implementations of various action heads, which serve as alternatives to VLM sequential token prediction.""" - -import math - -import numpy as np -import torch -import torch.nn as nn -from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX - - -class SinusoidalPositionalEncoding(nn.Module): - """ - Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps. 
- - For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,) - Then the output would be a batch of 32 timestep embeddings -> shape (32, D) - - Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim # dimensionality of the positional encoding - - def forward(self, x): - # x: (batch_size,) - device = x.device - assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}" - half_dim = self.dim // 2 - exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1) # shape: (D/2,) - emb = torch.exp(exponent) # shape: (D/2,) - emb = x[:, None] * emb[None, :] # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # shape: (batch_size, D) - return emb - - -class MLPResNetBlock(nn.Module): - """One MLP ResNet block with a residual connection.""" - def __init__(self, dim): - super().__init__() - self.dim = dim - self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers - nn.LayerNorm(dim), - nn.Linear(dim, dim), - nn.ReLU(), - ) - - def forward(self, x): - # x: (batch_size, hidden_dim) - # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as - # described here: https://arxiv.org/pdf/2002.04745.pdf - identity = x - x = self.ffn(x) - x = x + identity - return x - - -class MLPResNet(nn.Module): - """MLP with residual connection blocks.""" - def __init__(self, num_blocks, input_dim, hidden_dim, output_dim): - super().__init__() - self.layer_norm1 = nn.LayerNorm(input_dim) - self.fc1 = nn.Linear(input_dim, hidden_dim) - self.relu = nn.ReLU() - self.mlp_resnet_blocks = nn.ModuleList() - for _ in range(num_blocks): - self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) - self.layer_norm2 = nn.LayerNorm(hidden_dim) - self.fc2 = nn.Linear(hidden_dim, output_dim) - - def forward(self, x): - # x: (batch_size, input_dim) - x = self.layer_norm1(x) # shape: (batch_size, input_dim) - x = self.fc1(x) # shape: (batch_size, hidden_dim) - x = self.relu(x) # shape: (batch_size, hidden_dim) - for block in self.mlp_resnet_blocks: - x = block(x) # shape: (batch_size, hidden_dim) - x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) - x = self.fc2(x) # shape: (batch_size, output_dim) - return x - - -class L1RegressionActionHead(nn.Module): - """Simple MLP-based action head that generates continuous actions via L1 regression.""" - def __init__( - self, - input_dim=4096, - hidden_dim=4096, - action_dim=7, - ): - super().__init__() - self.action_dim = action_dim - self.model = MLPResNet( - num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim - ) - - def predict_action(self, actions_hidden_states): - # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence - # - shape: (batch_size, chunk_len * action_dim, hidden_dim) - # ground_truth_actions: ground-truth actions - # - shape: (batch_size, chunk_len, action_dim) - batch_size = actions_hidden_states.shape[0] - device = actions_hidden_states.device - rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) - action = self.model(rearranged_actions_hidden_states) - return action - - -class NoisePredictionModel(nn.Module): - """ - Diffusion noise prediction model that takes an 
observation embedding (which fuses the - noisy action, diffusion timestep, and image-language observation embeddings) and - outputs a noise prediction. - """ - - def __init__( - self, - transformer_hidden_dim, # Transformer hidden embedding size - hidden_dim, # MLP hidden size - action_dim=7, # action dimensionality - ): - super().__init__() - self.mlp_resnet = MLPResNet( - num_blocks=2, - input_dim=transformer_hidden_dim, - hidden_dim=hidden_dim, - output_dim=action_dim, - ) - - def forward( - self, - obs, - ): - # obs: observation embeddings to condition the generation on - # - shape: (batch_size, chunk_len, rearranged_hidden_dim=action_dim*hidden_dim) - # - # output: predicted noise - # - shape: (batch_size, action_dim) - output = self.mlp_resnet(obs) - return output - - -class DiffusionActionHead(nn.Module): - """ - Simple MLP-based action head that generates continuous actions via conditional denoising diffusion process. - - Loosely inspired by: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py - """ - - def __init__( - self, - input_dim=4096, - hidden_dim=4096, - action_dim=7, - num_diffusion_steps_train=50, - ): - super().__init__() - self.action_dim = action_dim - self.noise_predictor = NoisePredictionModel( - transformer_hidden_dim=hidden_dim*ACTION_DIM, hidden_dim=hidden_dim, action_dim=action_dim - ) - self.num_diffusion_steps_train = num_diffusion_steps_train - self.noise_scheduler = DDIMScheduler(num_train_timesteps=num_diffusion_steps_train, beta_schedule="squaredcos_cap_v2") - self.time_encoder = SinusoidalPositionalEncoding(dim=hidden_dim) - - def sample_noisy_actions(self, ground_truth_actions): - """ - Samples noise and applies noise to ground-truth actions to produce noisy actions, which are - used as input in the noise prediction network. Returns noise, noisy actions, and the - corresponding diffusion timestep embeddings. - """ - # ground_truth_actions: ground-truth actions - # - shape: (batch_size, chunk_len, action_dim) - batch_size = ground_truth_actions.shape[0] - device = ground_truth_actions.device - # Sample random noise with shape equal to actions, used for closed-form forward diffusion. - noise = torch.randn(size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), device=device, dtype=ground_truth_actions.dtype) # (B, chunk_len, action_dim) - # Sample random diffusion timesteps (one for each action in batch). - timesteps = torch.randint( - low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device - ) - # Add noise to clean actions according to the magnitude at each diffusion timestep via - # closed-form forward diffusion. - noisy_actions = self.noise_scheduler.add_noise(ground_truth_actions, noise, timesteps) # (B, chunk_len, action_dim) - - # Get diffusion timestep embeddings as well - diffusion_timestep_embeddings = self.time_encoder(timesteps).to(noisy_actions.dtype).to(noisy_actions.device) # (B, llm_dim) - diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) - - return_dict = dict( - noise=noise, - noisy_actions=noisy_actions, - diffusion_timestep_embeddings=diffusion_timestep_embeddings, - ) - - return return_dict - - def predict_noise(self, actions_hidden_states): - """ - Given a batch of last hidden Transformer layer embeddings (which fuse the vision-language observation embeddings, - noisy action embeddings, and diffusion timestep embedding), predicts the noise applied to the actions. 
- """ - # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence - # - shape: (batch_size, chunk_len * action_dim, hidden_dim) - batch_size = actions_hidden_states.shape[0] - device = actions_hidden_states.device - rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) # (batch_size, chunk_len, action_dim * hidden_dim) - # Get diffusion model's noise prediction. - noise_pred = self.noise_predictor(rearranged_actions_hidden_states) - return noise_pred diff --git a/capvector-oft/prismatic/models/backbones/__init__.py b/capvector-oft/prismatic/models/backbones/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-oft/prismatic/models/backbones/llm/__init__.py b/capvector-oft/prismatic/models/backbones/llm/__init__.py deleted file mode 100644 index a040f37d9c4f1e5354b74b6e24483c76fc11bf2c..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base_llm import LLMBackbone -from .llama2 import LLaMa2LLMBackbone -from .mistral import MistralLLMBackbone -from .phi import PhiLLMBackbone diff --git a/capvector-oft/prismatic/models/backbones/llm/base_llm.py b/capvector-oft/prismatic/models/backbones/llm/base_llm.py deleted file mode 100644 index fab6d971e71f8251421c91fbcb009f6c882111ae..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/base_llm.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -base_llm.py - -Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class -methods, utility functions, and initialization logic. - -We also define the generic HFLLMBackbone class here, providing a default interface for loading any HF -AutoModelForCausalLM (e.g., LLamaForCausalLM). In general, we make the assumption that any given LLM backbone implements -the AutoModelForCausalLM API (though we may add Seq2Seq models in the future). - -We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF -utilities around different types of decoding/generation strategies. -""" - -import warnings -from abc import ABC, abstractmethod -from functools import partial -from typing import Callable, List, Optional, Sequence, Type - -import torch -import torch.nn as nn -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase -from transformers.modeling_outputs import CausalLMOutputWithPast - -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.overwatch import initialize_overwatch - -# Suppress HF Deprecation Warnings -warnings.filterwarnings("ignore", category=FutureWarning) - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# === Abstract Base Class for arbitrary HF LLM Backbones === -class LLMBackbone(nn.Module, ABC): - def __init__(self, llm_backbone_id: str) -> None: - super().__init__() - self.identifier = llm_backbone_id - - # Instance attributes for an LLM Backbone - self.llm: PreTrainedModel = None - self.tokenizer: PreTrainedTokenizerBase = None - - def get_tokenizer(self) -> PreTrainedTokenizerBase: - return self.tokenizer - - @abstractmethod - def get_fsdp_wrapping_policy(self) -> Callable: ... 
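For reference, the forward-diffusion step used by `DiffusionActionHead.sample_noisy_actions` earlier is the standard closed-form noising x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, which `DDIMScheduler.add_noise` implements. A minimal sketch of that call pattern on toy action chunks; the batch size, chunk length of 8, and action dimension of 7 are assumptions for the example rather than the real constants:

```python
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

B, CHUNK_LEN, ACT_DIM = 4, 8, 7   # toy sizes; the real values come from the VLA constants module
scheduler = DDIMScheduler(num_train_timesteps=50, beta_schedule="squaredcos_cap_v2")

clean_actions = torch.randn(B, CHUNK_LEN, ACT_DIM)                        # ground-truth action chunk
noise = torch.randn_like(clean_actions)                                   # eps ~ N(0, I)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (B,))  # one diffusion step per sample

# Closed-form forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
noisy_actions = scheduler.add_noise(clean_actions, noise, timesteps)
assert noisy_actions.shape == clean_actions.shape
```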
- - @abstractmethod - def enable_gradient_checkpointing(self) -> None: ... - - @abstractmethod - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> CausalLMOutputWithPast: - """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss""" - raise NotImplementedError - - @abstractmethod - def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ... - - @property - @abstractmethod - def prompt_builder_fn(self) -> Type[PromptBuilder]: ... - - @property - @abstractmethod - def transformer_layer_cls(self) -> Type[nn.Module]: ... - - @property - @abstractmethod - def half_precision_dtype(self) -> torch.dtype: ... - - @property - @abstractmethod - def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ... - - @property - def embed_dim(self) -> int: - return self.llm.config.hidden_size - - @property - def pad_token_id(self) -> int: - return self.tokenizer.pad_token_id - - -# === Abstract Base Class for Arbitrary HF Causal LLMs === -class HFCausalLLMBackbone(LLMBackbone, ABC): - def __init__( - self, - llm_backbone_id: str, - llm_family: str, - llm_cls: Type[PreTrainedModel], - hf_hub_path: str, - llm_max_length: int = 2048, - hf_token: Optional[str] = None, - inference_mode: bool = False, - use_flash_attention_2: bool = False, - ) -> None: - super().__init__(llm_backbone_id) - self.llm_family = llm_family - self.llm_max_length = llm_max_length - self.inference_mode = inference_mode - - # Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class! - # => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details - if not self.inference_mode: - overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1) - self.llm = llm_cls.from_pretrained( - hf_hub_path, - token=hf_token, - use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False, - # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding! - do_sample=False, - temperature=1.0, - top_p=1.0, - ) - - # [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights! 
- else: - overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1) - llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token) - self.llm = llm_cls._from_config(llm_config) - - # Lightweight Handling (with extended explanation) for setting some LLM Parameters - # => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general) - # - # Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958 - self.llm.config.use_cache = False if not self.inference_mode else True - - # => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters - # (requires_grad is False), backprop will fail; setting `enable_input_requires_grad()` registers a new - # forward hook that fixes this =>> also totally safe for the "full finetuning" setting! - if not self.inference_mode: - self.llm.enable_input_require_grads() - - # Load (Fast) Tokenizer - overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1) - self.tokenizer = AutoTokenizer.from_pretrained( - hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right" - ) - - # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input - # starts with a token unless `add_special_tokens = False`; for these models, we empirically - # find that adding image patches *after* the BOS leads to much better performance. - # - # As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this - # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to - # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py` - # and VLM `forward()` logic! - SPECIAL_CASES = { - # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>" - # =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that - # this works well with base LLM generation. - # =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes. - "phi-2-3b", - } - if self.identifier in SPECIAL_CASES: - return - - # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral! - assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and ( - self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id - ), ( - f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n" - "Please read the comment in `base_llm.py` for more information!" 
- ) - - def get_fsdp_wrapping_policy(self) -> Callable: - """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`""" - transformer_block_policy = partial( - transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls} - ) - - return transformer_block_policy - - def enable_gradient_checkpointing(self) -> None: - """Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PretrainedModel`.""" - self.llm.gradient_checkpointing_enable() - - def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: - return self.llm.get_input_embeddings()(input_ids) - - # [Contract] Should match the `forward` call of the underlying `llm` instance! - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> CausalLMOutputWithPast: - output: CausalLMOutputWithPast = self.llm( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - return output diff --git a/capvector-oft/prismatic/models/backbones/llm/llama2.py b/capvector-oft/prismatic/models/backbones/llm/llama2.py deleted file mode 100644 index 559409e2e54d7c0c4d06f2691d104ffba991b48f..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/llama2.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -llama2.py - -Class definition for all LLMs derived from LlamaForCausalLM. 
-""" - -from typing import Optional, Sequence, Type - -import torch -from torch import nn as nn -from transformers import LlamaForCausalLM -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone -from prismatic.models.backbones.llm.prompting import ( - LLaMa2ChatPromptBuilder, - PromptBuilder, - PurePromptBuilder, - VicunaV15ChatPromptBuilder, -) - -# Registry =>> Support LLaMa-2 Models (from HF Transformers) -# fmt: off -LLAMA2_MODELS = { - # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === - "llama2-7b-pure": { - "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf" - }, - - "llama2-13b-pure": { - "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf" - }, - - # === Meta LLaMa-2 Chat Models === - "llama2-7b-chat": { - "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf" - }, - - "llama2-13b-chat": { - "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf" - }, - - # === Vicuna v1.5 Chat Models === - "vicuna-v15-7b": { - "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5" - }, - - "vicuna-v15-13b": { - "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5" - }, -} -# fmt: on - - -class LLaMa2LLMBackbone(HFCausalLLMBackbone): - def __init__( - self, - llm_backbone_id: str, - llm_max_length: int = 2048, - hf_token: Optional[str] = None, - inference_mode: bool = False, - use_flash_attention_2: bool = True, - ) -> None: - super().__init__( - llm_backbone_id, - llm_max_length=llm_max_length, - hf_token=hf_token, - inference_mode=inference_mode, - use_flash_attention_2=use_flash_attention_2, - **LLAMA2_MODELS[llm_backbone_id], - ) - - # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize) - self.tokenizer.add_special_tokens({"pad_token": ""}) - self.llm.config.pad_token_id = self.tokenizer.pad_token_id - self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) - - @property - def prompt_builder_fn(self) -> Type[PromptBuilder]: - if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): - return PurePromptBuilder - - elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): - return LLaMa2ChatPromptBuilder - - elif self.identifier.startswith("vicuna"): - return VicunaV15ChatPromptBuilder - - raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") - - @property - def transformer_layer_cls(self) -> Type[nn.Module]: - return LlamaDecoderLayer - - @property - def half_precision_dtype(self) -> torch.dtype: - """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2.""" - return torch.bfloat16 - - @property - def last_layer_finetune_modules(self) -> Sequence[nn.Module]: - return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head) diff --git a/capvector-oft/prismatic/models/backbones/llm/mistral.py b/capvector-oft/prismatic/models/backbones/llm/mistral.py deleted file mode 100644 index 0d2a41fc33a1acc5f92860625c0018736358338e..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/mistral.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -mistral.py - -Class definition for all LLMs derived from MistralForCausalLM. 
-""" - -from typing import Optional, Type - -import torch -from torch import nn as nn -from transformers import MistralForCausalLM -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer - -from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone -from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder - -# Registry =>> Support Mistral Models (from HF Transformers) -# fmt: off -MISTRAL_MODELS = { - # === Base Mistral v0.1 === - "mistral-v0.1-7b-pure": { - "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" - }, - - # === Mistral Instruct v0.1 === - "mistral-v0.1-7b-instruct": { - "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" - } -} -# fmt: on - - -class MistralLLMBackbone(HFCausalLLMBackbone): - def __init__( - self, - llm_backbone_id: str, - llm_max_length: int = 2048, - hf_token: Optional[str] = None, - inference_mode: bool = False, - use_flash_attention_2: bool = True, - ) -> None: - super().__init__( - llm_backbone_id, - llm_max_length=llm_max_length, - hf_token=hf_token, - inference_mode=inference_mode, - use_flash_attention_2=use_flash_attention_2, - **MISTRAL_MODELS[llm_backbone_id], - ) - - # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize) - self.tokenizer.add_special_tokens({"pad_token": ""}) - self.llm.config.pad_token_id = self.tokenizer.pad_token_id - self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) - - @property - def prompt_builder_fn(self) -> Type[PromptBuilder]: - if self.identifier.endswith("-pure"): - return PurePromptBuilder - - elif self.identifier.endswith("-instruct"): - return MistralInstructPromptBuilder - - raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") - - @property - def transformer_layer_cls(self) -> Type[nn.Module]: - return MistralDecoderLayer - - @property - def half_precision_dtype(self) -> torch.dtype: - return torch.bfloat16 diff --git a/capvector-oft/prismatic/models/backbones/llm/phi.py b/capvector-oft/prismatic/models/backbones/llm/phi.py deleted file mode 100644 index e9063b3f9ccd1b9b792400600a7fbb0a02150c1b..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/phi.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -phi.py - -Class definition for all LLMs derived from PhiForCausalLM. 
-""" - -from typing import Optional, Type - -import torch -from torch import nn as nn -from transformers import PhiForCausalLM -from transformers.models.phi.modeling_phi import PhiDecoderLayer - -from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone -from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder - -# Registry ==> Support Phi Models (from HF Transformers) -# fmt: off -PHI_MODELS = { - # === Phi-2 === - "phi-2-3b": { - "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2" - } -} -# fmt: on - - -class PhiLLMBackbone(HFCausalLLMBackbone): - def __init__( - self, - llm_backbone_id: str, - llm_max_length: int = 2048, - hf_token: Optional[str] = None, - inference_mode: bool = False, - use_flash_attention_2: bool = True, - ) -> None: - super().__init__( - llm_backbone_id, - llm_max_length=llm_max_length, - hf_token=hf_token, - inference_mode=inference_mode, - use_flash_attention_2=use_flash_attention_2, - **PHI_MODELS[llm_backbone_id], - ) - - # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize) - self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) - self.llm.config.pad_token_id = self.tokenizer.pad_token_id - self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) - - @property - def prompt_builder_fn(self) -> Type[PromptBuilder]: - if self.identifier.startswith("phi-2"): - return PhiPromptBuilder - - raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") - - @property - def transformer_layer_cls(self) -> Type[nn.Module]: - return PhiDecoderLayer - - @property - def half_precision_dtype(self) -> torch.dtype: - return torch.bfloat16 diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py b/capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py deleted file mode 100644 index 5c73b61119d11b4d246f9e4a98d1aa70aa821621..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base_prompter import PromptBuilder, PurePromptBuilder -from .llama2_chat_prompter import LLaMa2ChatPromptBuilder -from .mistral_instruct_prompter import MistralInstructPromptBuilder -from .phi_prompter import PhiPromptBuilder -from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py deleted file mode 100644 index 65c2e16d703d42c7849cd8de9a9cdbbd3361fee3..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -base_prompter.py - -Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs. -""" - -from abc import ABC, abstractmethod -from typing import Optional - - -class PromptBuilder(ABC): - def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: - self.model_family = model_family - - # Only some models define a system prompt => let subclasses handle this logic! - self.system_prompt = system_prompt - - @abstractmethod - def add_turn(self, role: str, message: str) -> str: ... - - @abstractmethod - def get_potential_prompt(self, user_msg: str) -> None: ... - - @abstractmethod - def get_prompt(self) -> str: ... 
- - -class PurePromptBuilder(PromptBuilder): - def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: - super().__init__(model_family, system_prompt) - - # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! - self.bos, self.eos = "<s>", "</s>" - - # Get role-specific "wrap" functions - self.wrap_human = lambda msg: f"In: {msg}\nOut: " - self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" - - # === `self.prompt` gets built up over multiple turns === - self.prompt, self.turn_count = "", 0 - - def add_turn(self, role: str, message: str) -> str: - assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") - message = message.replace("<image>", "").strip() - - if (self.turn_count % 2) == 0: - human_message = self.wrap_human(message) - wrapped_message = human_message - else: - gpt_message = self.wrap_gpt(message) - wrapped_message = gpt_message - - # Update Prompt - self.prompt += wrapped_message - - # Bump Turn Counter - self.turn_count += 1 - - # Return "wrapped_message" (effective string added to context) - return wrapped_message - - def get_potential_prompt(self, message: str) -> None: - # Assumes that it's always the user's (human's) turn! - prompt_copy = str(self.prompt) - - human_message = self.wrap_human(message) - prompt_copy += human_message - - return prompt_copy.removeprefix(self.bos).rstrip() - - def get_prompt(self) -> str: - # Remove prefix <bos> (if exists) because it gets auto-inserted by tokenizer! - return self.prompt.removeprefix(self.bos).rstrip() diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py deleted file mode 100644 index 3b5609aaec9f8f6b688182f96fed9a1db038ba95..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -llama2_prompter.py - -Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern -that's used by HF and other online tutorials. - -Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 -""" - -from typing import Optional - -from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder - -# Default System Prompt for Prismatic Models -SYS_PROMPTS = { - "prismatic": ( - "You are a helpful language and vision assistant. " - "You are able to understand the visual content that the user provides, " - "and assist the user with a variety of tasks using natural language." - ), - "openvla": ( - "You are a helpful language and vision assistant. " - "You are able to understand the visual content that the user provides, " - "and assist the user with a variety of tasks using natural language."
- ), -} - - -def format_system_prompt(system_prompt: str) -> str: - return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n" - - -class LLaMa2ChatPromptBuilder(PromptBuilder): - def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: - super().__init__(model_family, system_prompt) - self.system_prompt = format_system_prompt( - SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt - ) - - # LLaMa-2 Specific - self.bos, self.eos = "<s>", "</s>" - - # Get role-specific "wrap" functions - self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " - self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" - - # === `self.prompt` gets built up over multiple turns === - self.prompt, self.turn_count = "", 0 - - def add_turn(self, role: str, message: str) -> str: - assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") - message = message.replace("<image>", "").strip() - - # Special Handling for "system" prompt (turn_count == 0) - if self.turn_count == 0: - sys_message = self.wrap_human(self.system_prompt + message) - wrapped_message = sys_message - elif (self.turn_count % 2) == 0: - human_message = self.wrap_human(message) - wrapped_message = human_message - else: - gpt_message = self.wrap_gpt(message) - wrapped_message = gpt_message - - # Update Prompt - self.prompt += wrapped_message - - # Bump Turn Counter - self.turn_count += 1 - - # Return "wrapped_message" (effective string added to context) - return wrapped_message - - def get_potential_prompt(self, message: str) -> None: - # Assumes that it's always the user's (human's) turn! - prompt_copy = str(self.prompt) - - # Special Handling for "system" prompt (turn_count == 0) - if self.turn_count == 0: - sys_message = self.wrap_human(self.system_prompt + message) - prompt_copy += sys_message - - else: - human_message = self.wrap_human(message) - prompt_copy += human_message - - return prompt_copy.removeprefix(self.bos).rstrip() - - def get_prompt(self) -> str: - # Remove prefix <bos> because it gets auto-inserted by tokenizer!
- return self.prompt.removeprefix(self.bos).rstrip() diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py deleted file mode 100644 index e9a22b541ff9aabe3a69fa663dc422ed14acae31..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -mistral_instruct_prompter.py - -Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials. - -Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format -""" - -from typing import Optional - -from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder - - -class MistralInstructPromptBuilder(PromptBuilder): - def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: - super().__init__(model_family, system_prompt) - - # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)` - # =>> Mistral Instruct *does not* use a System Prompt - self.bos, self.eos = "<s>", "</s>" - - # Get role-specific "wrap" functions - self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " - self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" - - # === `self.prompt` gets built up over multiple turns === - self.prompt, self.turn_count = "", 0 - - def add_turn(self, role: str, message: str) -> str: - assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") - message = message.replace("<image>", "").strip() - - if (self.turn_count % 2) == 0: - human_message = self.wrap_human(message) - wrapped_message = human_message - else: - gpt_message = self.wrap_gpt(message) - wrapped_message = gpt_message - - # Update Prompt - self.prompt += wrapped_message - - # Bump Turn Counter - self.turn_count += 1 - - # Return "wrapped_message" (effective string added to context) - return wrapped_message - - def get_potential_prompt(self, message: str) -> None: - # Assumes that it's always the user's (human's) turn! - prompt_copy = str(self.prompt) - - human_message = self.wrap_human(message) - prompt_copy += human_message - - return prompt_copy.removeprefix(self.bos).rstrip() - - def get_prompt(self) -> str: - # Remove prefix <bos> because it gets auto-inserted by tokenizer! - return self.prompt.removeprefix(self.bos).rstrip() diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py deleted file mode 100644 index 3843a33bc86a716002ef52a0e067bff5101407f2..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -phi_prompter.py - -Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft. -Also handles Phi special case BOS token additions. - -Reference: https://huggingface.co/microsoft/phi-2#qa-format -""" - -from typing import Optional - -from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder - - -class PhiPromptBuilder(PromptBuilder): - def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: - super().__init__(model_family, system_prompt) - - # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)` - # =>> By default, does *not* append <bos> / <eos> tokens --> we handle that here (IMPORTANT)!
- self.bos, self.eos = "<|endoftext|>", "<|endoftext|>" - - # Get role-specific "wrap" functions - # =>> Note that placement of / were based on experiments generating from Phi-2 in Input/Output mode - self.wrap_human = lambda msg: f"Input: {msg}\nOutput: " - self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}" - - # === `self.prompt` gets built up over multiple turns === - self.prompt, self.turn_count = "", 0 - - def add_turn(self, role: str, message: str) -> str: - assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") - message = message.replace("", "").strip() - - # Special Handling for "first" input --> prepend a token (expected by Prismatic) - if self.turn_count == 0: - bos_human_message = f"{self.bos}{self.wrap_human(message)}" - wrapped_message = bos_human_message - elif (self.turn_count % 2) == 0: - human_message = self.wrap_human(message) - wrapped_message = human_message - else: - gpt_message = self.wrap_gpt(message) - wrapped_message = gpt_message - - # Update Prompt - self.prompt += wrapped_message - - # Bump Turn Counter - self.turn_count += 1 - - # Return "wrapped_message" (effective string added to context) - return wrapped_message - - def get_potential_prompt(self, message: str) -> None: - # Assumes that it's always the user's (human's) turn! - prompt_copy = str(self.prompt) - - human_message = self.wrap_human(message) - prompt_copy += human_message - - return prompt_copy.rstrip() - - def get_prompt(self) -> str: - return self.prompt.rstrip() diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py deleted file mode 100644 index 5ea246a16533f580332716361755f98f5aabe01d..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -vicuna_v15_prompter.py - -Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. - -Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 -""" - -from typing import Optional - -from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder - -# Default System Prompt for LLaVa Models -SYS_PROMPTS = { - "prismatic": ( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." - ), - "openvla": ( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." 
- ), -} - - -class VicunaV15ChatPromptBuilder(PromptBuilder): - def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: - super().__init__(model_family, system_prompt) - self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " - - # LLaMa-2 Specific - self.bos, self.eos = "<s>", "</s>" - - # Get role-specific "wrap" functions - self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " - self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" - - # === `self.prompt` gets built up over multiple turns === - self.prompt, self.turn_count = "", 0 - - def add_turn(self, role: str, message: str) -> str: - assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") - message = message.replace("<image>", "").strip() - - # Special Handling for "system" prompt (turn_count == 0) - if self.turn_count == 0: - sys_message = self.system_prompt + self.wrap_human(message) - wrapped_message = sys_message - elif (self.turn_count % 2) == 0: - human_message = self.wrap_human(message) - wrapped_message = human_message - else: - gpt_message = self.wrap_gpt(message) - wrapped_message = gpt_message - - # Update Prompt - self.prompt += wrapped_message - - # Bump Turn Counter - self.turn_count += 1 - - # Return "wrapped_message" (effective string added to context) - return wrapped_message - - def get_potential_prompt(self, message: str) -> None: - # Assumes that it's always the user's (human's) turn! - prompt_copy = str(self.prompt) - - # Special Handling for "system" prompt (turn_count == 0) - if self.turn_count == 0: - sys_message = self.system_prompt + self.wrap_human(message) - prompt_copy += sys_message - - else: - human_message = self.wrap_human(message) - prompt_copy += human_message - - return prompt_copy.removeprefix(self.bos).rstrip() - - def get_prompt(self) -> str: - # Remove prefix <bos> (if exists) because it gets auto-inserted by tokenizer! - return self.prompt.removeprefix(self.bos).rstrip() diff --git a/capvector-oft/prismatic/models/backbones/vision/__init__.py b/capvector-oft/prismatic/models/backbones/vision/__init__.py deleted file mode 100644 index 3c6da9a186cb68050eb11688b20177fc0ee4359c..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base_vision import ImageTransform, VisionBackbone -from .clip_vit import CLIPViTBackbone -from .dinoclip_vit import DinoCLIPViTBackbone -from .dinosiglip_vit import DinoSigLIPViTBackbone -from .dinov2_vit import DinoV2ViTBackbone -from .in1k_vit import IN1KViTBackbone -from .siglip_vit import SigLIPViTBackbone diff --git a/capvector-oft/prismatic/models/backbones/vision/base_vision.py b/capvector-oft/prismatic/models/backbones/vision/base_vision.py deleted file mode 100644 index 8268c4dd53c2caa7bc98efbf1818644d46515cd3..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/base_vision.py +++ /dev/null @@ -1,207 +0,0 @@ -""" -base_vision.py - -Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility -functions, and initialization logic. - -We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision -Transformer model for feature extraction.
-""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from functools import partial -from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union - -import timm -import torch -import torch.nn as nn -import torchvision.transforms.functional as TVF -from PIL.Image import Image -from timm.models.vision_transformer import Block, VisionTransformer -from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy -from torchvision.transforms import Compose, Resize - - -# === Utility Functions for Monkey-Patching === -def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: - def wrapper(*args: Any, **kwargs: Any) -> Any: - result = fn(*args, **kwargs) - return result[0] if isinstance(result, tuple) else result - - return wrapper - - -# === Interface for an Image Transform === -class ImageTransform(Protocol): - def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... - - -# === Custom Torchvision Image Transforms === -@dataclass -class LetterboxPad: - padding_fill_value: Tuple[int, int, int] - - def __call__(self, image: Image) -> Image: - """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" - (w, h), max_wh = image.size, max(image.size) - horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2) - padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) - return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant") - - -# === Abstract Base Class for arbitrary Vision Backbones === -class VisionBackbone(nn.Module, ABC): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__() - self.identifier: str = vision_backbone_id - self.image_resize_strategy: str = image_resize_strategy - self.default_image_size: int = default_image_size - - # Instance attributes for a Vision Backbone - self.featurizer: nn.Module = None - self.image_transform: ImageTransform = None - - def get_image_transform(self) -> ImageTransform: - return self.image_transform - - @abstractmethod - def get_fsdp_wrapping_policy(self) -> Callable: ... - - @abstractmethod - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features.""" - raise NotImplementedError - - @property - @abstractmethod - def default_image_resolution(self) -> Tuple[int, int, int]: ... - - @property - @abstractmethod - def embed_dim(self) -> int: ... - - @property - @abstractmethod - def num_patches(self) -> int: ... - - @property - @abstractmethod - def half_precision_dtype(self) -> torch.dtype: ... 
- - -# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones === -class TimmViTBackbone(VisionBackbone, ABC): - def __init__( - self, - vision_backbone_id: str, - timm_path_or_url: str, - image_resize_strategy: str, - default_image_size: int = 224, - override_act_layer: Optional[str] = None, - ) -> None: - super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) - self.timm_path_or_url = timm_path_or_url - self.override_act_layer = override_act_layer - self.dtype = torch.bfloat16 - - # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary - if self.override_act_layer is None: - self.featurizer: VisionTransformer = timm.create_model( - self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size - ) - else: - self.featurizer: VisionTransformer = timm.create_model( - self.timm_path_or_url, - pretrained=True, - num_classes=0, - img_size=self.default_image_size, - act_layer=self.override_act_layer, - ) - self.featurizer.eval() - - # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility - # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! - # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 - self.featurizer.forward = unpack_tuple( - partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2}) - ) - - # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!) - assert isinstance(self.featurizer, VisionTransformer), ( - "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, " - "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!" - ) - - # Get Config =>> Note :: Override default image size to ensure correct image transform - self.data_cfg = timm.data.resolve_model_data_config(self.featurizer) - self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) - - # Initialize Default Image Transform --> Modified by `self.image_resize_strategy` - default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False) - - # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)! - if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url: - assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" - assert isinstance(default_image_transform.transforms[0], Resize) - default_image_transform = Compose( - [ - Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation), - *default_image_transform.transforms[1:], - ] - ) - - # Switch on `image_resize_strategy` - if self.image_resize_strategy == "resize-naive": - assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" 
- assert isinstance(default_image_transform.transforms[0], Resize) - - target_size = (self.default_image_size, self.default_image_size) - self.image_transform = Compose( - [ - Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation), - *default_image_transform.transforms[1:], - ] - ) - - elif self.image_resize_strategy == "resize-crop": - self.image_transform = default_image_transform - - elif self.image_resize_strategy == "letterbox": - assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" - assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!" - - # Compute Padding Fill Value (rescaled normalization mean if applicable) - fill = tuple([int(x * 255) for x in self.data_cfg["mean"]]) - - # Build New Transform - self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms]) - - else: - raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") - - def get_fsdp_wrapping_policy(self) -> Callable: - """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer.""" - vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) - transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) - return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) - - def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: - """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features.""" - return self.featurizer(pixel_values) - - @property - def default_image_resolution(self) -> Tuple[int, int, int]: - return self.data_cfg["input_size"] - - @property - def embed_dim(self) -> int: - return self.featurizer.embed_dim - - @property - def num_patches(self) -> int: - return self.featurizer.patch_embed.num_patches - - @property - def half_precision_dtype(self) -> torch.dtype: - return self.dtype diff --git a/capvector-oft/prismatic/models/backbones/vision/clip_vit.py b/capvector-oft/prismatic/models/backbones/vision/clip_vit.py deleted file mode 100644 index 1023d0b8ee9500547c1649bfb3d82493e6c2659d..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/clip_vit.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -clip_vit.py -""" - -from prismatic.models.backbones.vision.base_vision import TimmViTBackbone - -# Registry =>> Supported CLIP Vision Backbones (from TIMM) -CLIP_VISION_BACKBONES = { - "clip-vit-b": "vit_base_patch16_clip_224.openai", - "clip-vit-l": "vit_large_patch14_clip_224.openai", - "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", -} - - -# [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. 
-# HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's -# a decent approximation, the resulting features are *worse*; this was a super tricky bug -# to identify, but luckily there's an easy fix (`override_act_layer`) -class CLIPViTBackbone(TimmViTBackbone): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__( - vision_backbone_id, - CLIP_VISION_BACKBONES[vision_backbone_id], - image_resize_strategy, - default_image_size=default_image_size, - override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, - ) diff --git a/capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py b/capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py deleted file mode 100644 index 318598a3197ef72afb2ab10c7c48285a7e6d4284..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -dinoclip_vit.py - -Vision backbone that returns concatenated features from both DINOv2 and CLIP. -""" - -from dataclasses import dataclass -from functools import partial -from typing import Callable, Dict, Tuple - -import timm -import torch -from PIL import Image -from timm.models.vision_transformer import Block, VisionTransformer -from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy -from torchvision.transforms import Compose, Resize - -from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple - -# Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) -DINOCLIP_VISION_BACKBONES = { - "dinoclip-vit-l-336px": { - "dino": "vit_large_patch14_reg4_dinov2.lvd142m", - "clip": "vit_large_patch14_clip_336.openai", - }, -} - - -@dataclass -class DinoCLIPImageTransform: - dino_image_transform: ImageTransform - clip_image_transform: ImageTransform - is_prismatic: bool = True - - def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: - return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} - - -class DinoCLIPViTBackbone(VisionBackbone): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) - self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] - self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] - - # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary - self.dino_featurizer: VisionTransformer = timm.create_model( - self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size - ) - self.dino_featurizer.eval() - - self.clip_featurizer: VisionTransformer = timm.create_model( - self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size - ) - self.clip_featurizer.eval() - - # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility - # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
- # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 - self.dino_featurizer.forward = unpack_tuple( - partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) - ) - self.clip_featurizer.forward = unpack_tuple( - partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) - ) - - # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models - self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) - self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) - - self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) - self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) - - # Initialize *both* Transforms - default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) - default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) - if self.image_resize_strategy == "resize-naive": - assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" - assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" - assert isinstance(default_dino_transform.transforms[0], Resize) - assert isinstance(default_clip_transform.transforms[0], Resize) - - target_size = (self.default_image_size, self.default_image_size) - dino_transform = Compose( - [ - Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), - *default_dino_transform.transforms[1:], - ] - ) - clip_transform = Compose( - [ - Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), - *default_clip_transform.transforms[1:], - ] - ) - - self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) - - elif self.image_resize_strategy == "resize-crop": - self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) - - elif self.image_resize_strategy == "letterbox": - assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" - assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" - assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
- - # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) - dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) - clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) - - # Build New Transform - self.image_transform = DinoCLIPImageTransform( - Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), - Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), - ) - - else: - raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") - - def get_fsdp_wrapping_policy(self) -> Callable: - """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" - vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) - transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) - return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) - - def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: - """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" - dino_patches = self.dino_featurizer(pixel_values["dino"]) - clip_patches = self.clip_featurizer(pixel_values["clip"]) - - return torch.cat([dino_patches, clip_patches], dim=2) - - @property - def default_image_resolution(self) -> Tuple[int, int, int]: - return self.dino_data_cfg["input_size"] - - @property - def embed_dim(self) -> int: - return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim - - @property - def num_patches(self) -> int: - assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches - return self.dino_featurizer.patch_embed.num_patches - - @property - def half_precision_dtype(self) -> torch.dtype: - return torch.bfloat16 diff --git a/capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py b/capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py deleted file mode 100644 index c8762dadf0ff756bd9b12642917c21ac58c5d5eb..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -dinosiglip_vit.py - -Vision backbone that returns concatenated features from both DINOv2 and SigLIP. 
-""" - -from dataclasses import dataclass -from functools import partial -from typing import Callable, Dict, Tuple - -import timm -import torch -from PIL import Image -from timm.models.vision_transformer import Block, VisionTransformer -from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy -from torchvision.transforms import Compose, Resize - -from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple - -# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) -DINOSigLIP_VISION_BACKBONES = { - "dinosiglip-vit-so-224px": { - "dino": "vit_large_patch14_reg4_dinov2.lvd142m", - "siglip": "vit_so400m_patch14_siglip_224", - }, - "dinosiglip-vit-so-384px": { - "dino": "vit_large_patch14_reg4_dinov2.lvd142m", - "siglip": "vit_so400m_patch14_siglip_384", - }, -} - - -@dataclass -class DinoSigLIPImageTransform: - dino_image_transform: ImageTransform - siglip_image_transform: ImageTransform - is_prismatic: bool = True - - def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: - return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)} - - -class DinoSigLIPViTBackbone(VisionBackbone): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) - self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"] - self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"] - - # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary - self.dino_featurizer: VisionTransformer = timm.create_model( - self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size - ) - self.dino_featurizer.eval() - - self.siglip_featurizer: VisionTransformer = timm.create_model( - self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size - ) - self.siglip_featurizer.eval() - - # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility - # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! - # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 - self.dino_featurizer.forward = unpack_tuple( - partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) - ) - self.siglip_featurizer.forward = unpack_tuple( - partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2}) - ) - - # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models - self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) - self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) - - self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer) - self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) - - # Initialize *both* Transforms - default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) - default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False) - - # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!! 
- assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!" - assert isinstance(default_siglip_transform.transforms[0], Resize) - default_siglip_transform = Compose( - [ - Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation), - *default_siglip_transform.transforms[1:], - ] - ) - - if self.image_resize_strategy == "resize-naive": - assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" - assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!" - assert isinstance(default_dino_transform.transforms[0], Resize) - assert isinstance(default_siglip_transform.transforms[0], Resize) - - target_size = (self.default_image_size, self.default_image_size) - dino_transform = Compose( - [ - Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), - *default_dino_transform.transforms[1:], - ] - ) - siglip_transform = Compose( - [ - Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation), - *default_siglip_transform.transforms[1:], - ] - ) - - self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform) - - elif self.image_resize_strategy == "resize-crop": - self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform) - - elif self.image_resize_strategy == "letterbox": - assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" - assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!" - assert ( - "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg - ), "DinoSigLIP `data_cfg` missing `mean`!" - - # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) - dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) - siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]]) - - # Build New Transform - self.image_transform = DinoSigLIPImageTransform( - Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), - Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]), - ) - - else: - raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") - - def get_fsdp_wrapping_policy(self) -> Callable: - """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" - vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) - transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) - return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) - - def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: - """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" - dino_patches = self.dino_featurizer(pixel_values["dino"]) - siglip_patches = self.siglip_featurizer(pixel_values["siglip"]) - - return torch.cat([dino_patches, siglip_patches], dim=2) - - @property - def default_image_resolution(self) -> Tuple[int, int, int]: - return self.dino_data_cfg["input_size"] - - @property - def embed_dim(self) -> int: - return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim - - @property - def num_patches(self) -> int: - assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches - return 
self.dino_featurizer.patch_embed.num_patches - - @property - def half_precision_dtype(self) -> torch.dtype: - return torch.bfloat16 diff --git a/capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py b/capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py deleted file mode 100644 index d36acee29fffd3a202d24c1ee1d7c55cd53fa8ea..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -dinov2_vit.py -""" - -from prismatic.models.backbones.vision.base_vision import TimmViTBackbone - -# Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! -# => Reference: https://arxiv.org/abs/2309.16588 -DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} - - -class DinoV2ViTBackbone(TimmViTBackbone): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__( - vision_backbone_id, - DINOv2_VISION_BACKBONES[vision_backbone_id], - image_resize_strategy, - default_image_size=default_image_size, - ) diff --git a/capvector-oft/prismatic/models/backbones/vision/in1k_vit.py b/capvector-oft/prismatic/models/backbones/vision/in1k_vit.py deleted file mode 100644 index ba8fb0ee919851e5b9698b998b35513371873f6a..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/in1k_vit.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -in1k_vit.py - -Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) -""" - -from prismatic.models.backbones.vision.base_vision import TimmViTBackbone - -# Registry =>> Supported Vision Backbones (from TIMM) -IN1K_VISION_BACKBONES = { - "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", -} - - -class IN1KViTBackbone(TimmViTBackbone): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__( - vision_backbone_id, - IN1K_VISION_BACKBONES[vision_backbone_id], - image_resize_strategy, - default_image_size=default_image_size, - ) diff --git a/capvector-oft/prismatic/models/backbones/vision/siglip_vit.py b/capvector-oft/prismatic/models/backbones/vision/siglip_vit.py deleted file mode 100644 index 618ff087134003b1ee7d237d500cfe90a42ef798..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/backbones/vision/siglip_vit.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -siglip_vit.py -""" - -from prismatic.models.backbones.vision.base_vision import TimmViTBackbone - -# Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) -SIGLIP_VISION_BACKBONES = { - "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", - "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", - "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", - "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", - "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", -} - - -class SigLIPViTBackbone(TimmViTBackbone): - def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: - super().__init__( - vision_backbone_id, - SIGLIP_VISION_BACKBONES[vision_backbone_id], - image_resize_strategy, - default_image_size=default_image_size, - ) diff --git a/capvector-oft/prismatic/models/film_vit_wrapper.py b/capvector-oft/prismatic/models/film_vit_wrapper.py deleted file mode 100644 index 
696d8ed1ada59d0ac62e16827168017d8cf3365d..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/film_vit_wrapper.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Implementation of additional modules for the VLA's vision transformer.""" - -from functools import partial -from typing import Any, Callable, Sequence, Tuple, Union - -import torch -import torch.nn as nn -from timm.models.vision_transformer import VisionTransformer - - -class FiLMedVisionTransformerBlock(nn.Module): - """ - Wrapper for ViT blocks that adds components to implement FiLM language conditioning. - - Modulates visual feature embeddings via - x = (1 + gamma) * x + beta, - where x is visual feature and gamma and beta are learned projections of the average language embedding. - gamma and beta have D dimensions each, where D is the number of hidden dimensions in the ViT's features. - - NOTE #1 (Moo Jin): - In convolutional neural architectures, the "feature" in FiLM is an entire feature map, i.e., each channel in a - convolutional layer (so gamma and beta have C dimensions, where C is the number of channels). Therefore, FiLM's - scaling and shifting is applied across all spatial locations for conv nets -- i.e., it is spatially agnostic. - - For vision transformer architectures, you may consider individual patch embeddings as individual "features" at first - instinct, but this would make FiLM scaling and shifting spatially local. In order to make the modulation spatially - global like in convolutional architectures, we should apply the scaling and shifting to each dimension of each patch - embedding. I.e., gamma and beta should have D dimensions, where D is the number of dimensions in a visual embedding. - - NOTE #2 (Moo Jin): - x = (1 + gamma) * x + beta is used in the original FiLM paper as opposed to x = gamma * x + beta (see section 7.2 in - https://arxiv.org/pdf/1709.07871.pdf). Since gamma and beta are close to zero upon initialization, this leads to an - identity transformation at the beginning of training, which minimizes perturbation to the pretrained representation. - """ - - def __init__( - self, - block, - vision_dim: int, - llm_dim: int, - ): - """ - Initializes FiLM ViT block wrapper. - - Args: - block (timm.models.vision_transformer.Block): Vision transformer block. - vision_dim (int): Number of hidden dimensions in visual embeddings. - llm_dim (int): Number of hidden dimensions in language embeddings. - """ - super().__init__() - self.block = block - # Initialize gamma and beta projectors - self.scale = nn.Linear(llm_dim, vision_dim) - self.shift = nn.Linear(llm_dim, vision_dim) - - def forward(self, x, average_language_embedding): - """ - Overrides the vision transformer block forward pass to use FiLM. - - Args: - x (torch.Tensor): Visual input embeddings, (batch_size, vision_seq_len, vision_dim). - average_language_embedding (torch.Tensor): Average language embedding for task, (batch_size, llm_dim). 
- """ - # Project average language embedding to visual embedding space to get gamma and beta - gamma = self.scale(average_language_embedding) # (batch_size, vision_dim) - beta = self.shift(average_language_embedding) # (batch_size, vision_dim) - - # Pass visual inputs through attention portion of original block - x = x + self.block.drop_path1(self.block.ls1(self.block.attn(self.block.norm1(x)))) - - # Modulate intermediate visual representations via FiLM - x = x * (1 + gamma.view(gamma.shape[0], 1, gamma.shape[1])) + beta.view(beta.shape[0], 1, beta.shape[1]) - - # Pass visual inputs through feedforward portion of original block - x = x + self.block.drop_path2(self.block.ls2(self.block.mlp(self.block.norm2(x)))) - - return x - - -class NullVisionTransformerBlockWrapper(nn.Module): - """ - Null wrapper for ViT blocks that doesn't do anything; just calls the original block's forward function. - Useful if you want to use a block wrapper every X blocks instead of every block (e.g., to reduce the number of new - parameters introduced by a new wrapper). - """ - - def __init__( - self, - block, - ): - super().__init__() - self.block = block - - def forward(self, x, average_language_embedding): - return self.block(x) - - -def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: - """Utility function for monkey-patching functions.""" - - def wrapper(*args: Any, **kwargs: Any) -> Any: - result = fn(*args, **kwargs) - return result[0] if isinstance(result, tuple) else result - - return wrapper - - -class FiLMedVisionTransformer(VisionTransformer): - """ - Wrapper for timm.models.vision_transformer.VisionTransformer that overrides functions to enable infusing language - embeddings into visual embeddings via FiLM. - """ - - def _intermediate_layers( - self, - x: torch.Tensor, - language_embeddings: torch.Tensor, - n: Union[int, Sequence] = 1, - ): - """ - Copy of timm.models.vision_transformer.VisionTransformer._intermediate_layers() with modifications - to take in language embeddings as additional input. - """ - outputs, num_blocks = [], len(self.blocks) - take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) - - # forward pass - x = self.patch_embed(x) - x = self._pos_embed(x) - x = self.patch_drop(x) - x = self.norm_pre(x) - for i, blk in enumerate(self.blocks): - x = blk(x, language_embeddings) # Modified to receive language_embeddings - if i in take_indices: - outputs.append(x) - - return outputs - - def get_intermediate_layers( - self, - x: torch.Tensor, - language_embeddings: torch.Tensor, - n: Union[int, Sequence] = 1, - reshape: bool = False, - return_prefix_tokens: bool = False, - norm: bool = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - """ - Copy of timm.models.vision_transformer.VisionTransformer.get_intermediate_layers() with modifications - to allow language embeddings as additional input. 
- """ - # take last n blocks if n is an int, if in is a sequence, select by matching indices - outputs = self._intermediate_layers(x, language_embeddings, n) - if norm: - outputs = [self.norm(out) for out in outputs] - prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] - outputs = [out[:, self.num_prefix_tokens :] for out in outputs] - - if reshape: - grid_size = self.patch_embed.grid_size - outputs = [ - out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous() - for out in outputs - ] - - if return_prefix_tokens: - return tuple(zip(outputs, prefix_tokens)) - return tuple(outputs) - - -class FiLMedPrismaticVisionBackbone(nn.Module): - """ - Wrapper for OpenVLA's vision backbone that implements feature-wise linear modulation (FiLM). - - Wraps the Vision Transformers in the vision backbone to enable language conditioning through FiLM. - Supports processing 1-3 images using dual vision backbones (SigLIP + DINOv2). - """ - - def __init__( - self, - vision_backbone, - llm_dim: int = 4096, # 4096 for Llama-2 7B - ) -> None: - """ - Initializes FiLM wrapper. - - Args: - vision_backbone (PrismaticVisionBackbone): Base vision backbone. - llm_dim (int): Dimension of language model embeddings. - """ - super().__init__() - self.vision_backbone = vision_backbone - self.llm_dim = llm_dim - - # Wrap vision transformers - self._wrap_vit(self.vision_backbone.featurizer) # SigLIP - if self.vision_backbone.use_fused_vision_backbone: - self._wrap_vit(self.vision_backbone.fused_featurizer) # DINOv2 - - def _wrap_vit(self, vit) -> None: - """ - Creates wrapper around an individual vision transformer to allow for infusion of language inputs. - - Args: - vit (VisionTransformer): Original vision transformer. - """ - # Wrap vision transformer blocks - block_wrappers = [] - for block in vit.blocks: - block_wrappers.append( - FiLMedVisionTransformerBlock(block=block, vision_dim=vit.num_features, llm_dim=self.llm_dim) - ) - vit.blocks = nn.Sequential(*block_wrappers) - - # Wrap vision transformer with new class that overrides functions used for forward pass - vit.__class__ = FiLMedVisionTransformer - vit.forward = unpack_tuple(partial(vit.get_intermediate_layers, n={len(vit.blocks) - 2})) - - def get_num_patches(self) -> int: - """Returns the number of vision patches output by the vision backbone.""" - return self.vision_backbone.get_num_patches() - - def get_num_images_in_input(self) -> int: - """Returns the number of input images for the vision backbone.""" - return self.vision_backbone.get_num_images_in_input() - - def set_num_images_in_input(self, num_images_in_input: int) -> None: - """Sets the number of input images for the vision backbone.""" - self.vision_backbone.set_num_images_in_input(num_images_in_input) - - def forward(self, pixel_values: torch.Tensor, language_embeddings: torch.Tensor) -> torch.Tensor: - """ - Implements the forward pass for the vision backbone with FiLM to infuse language inputs into visual features. - - Identical to PrismaticVisionBackbone.forward() except that language embeddings are also used as input. - - Args: - pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W). - language_embeddings (torch.Tensor): Language embeddings for the task description, (B, seq_len, llm_dim). 
- """ - # For FiLM: Average the language embeddings of the task description - average_language_embedding = language_embeddings.mean(dim=1) - - if self.get_num_images_in_input() == 1: - if not self.vision_backbone.use_fused_vision_backbone: - return self.vision_backbone.featurizer(pixel_values, average_language_embedding) - - # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack - img, img_fused = torch.split(pixel_values, [3, 3], dim=1) - patches = self.vision_backbone.featurizer(img, average_language_embedding) - patches_fused = self.vision_backbone.fused_featurizer(img_fused, average_language_embedding) - - return torch.cat([patches, patches_fused], dim=2) - - else: - assert self.vision_backbone.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!" - - # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2) - images = torch.split(pixel_values, [6] * self.get_num_images_in_input(), dim=1) - - # Process each image and collect patches - all_patches = [] - for img in images: - # Split each image further into two stacks of channels (each with 3 channels) - img_regular, img_fused = torch.split(img, [3, 3], dim=1) - - # Get patches from both SigLIP and DINOv2 vision transformers - patches = self.vision_backbone.featurizer(img_regular, average_language_embedding) - patches_fused = self.vision_backbone.fused_featurizer(img_fused, average_language_embedding) - - # Concatenate SigLIP and DINOv2 patches along the hidden dimension - combined_patches = torch.cat([patches, patches_fused], dim=2) - all_patches.append(combined_patches) - - # Concatenate all patches along the patch dimension - return torch.cat(all_patches, dim=1) diff --git a/capvector-oft/prismatic/models/load.py b/capvector-oft/prismatic/models/load.py deleted file mode 100644 index 76cc3a3ae2362806d4d179fc4e427a9ac89eeb84..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/load.py +++ /dev/null @@ -1,226 +0,0 @@ -""" -load.py - -Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical -IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub). 
-""" - -import json -import os -from pathlib import Path -from typing import List, Optional, Union - -from huggingface_hub import HfFileSystem, hf_hub_download - -from prismatic.conf import ModelConfig -from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform -from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY -from prismatic.models.vlas import OpenVLA -from prismatic.models.vlms import PrismaticVLM -from prismatic.overwatch import initialize_overwatch -from prismatic.vla.action_tokenizer import ActionTokenizer - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# === HF Hub Repository === -HF_HUB_REPO = "TRI-ML/prismatic-vlms" -VLA_HF_HUB_REPO = "openvla/openvla-dev" - - -# === Available Models === -def available_models() -> List[str]: - return list(MODEL_REGISTRY.keys()) - - -def available_model_names() -> List[str]: - return list(GLOBAL_REGISTRY.items()) - - -def get_model_description(model_id_or_name: str) -> str: - if model_id_or_name not in GLOBAL_REGISTRY: - raise ValueError(f"Couldn't find `{model_id_or_name = }; check `prismatic.available_model_names()`") - - # Print Description & Return - print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2)) - - return description - - -# === Load Pretrained Model === -def load( - model_id_or_path: Union[str, Path], - hf_token: Optional[str] = None, - cache_dir: Optional[Union[str, Path]] = None, - load_for_training: bool = False, -) -> PrismaticVLM: - """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub.""" - if os.path.isdir(model_id_or_path): - overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`") - - # Get paths for `config.json` and pretrained checkpoint - config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" - assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" - assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" - else: - if model_id_or_path not in GLOBAL_REGISTRY: - raise ValueError(f"Couldn't find `{model_id_or_path = }; check `prismatic.available_model_names()`") - - overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])} from HF Hub") - with overwatch.local_zero_first(): - config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir) - checkpoint_pt = hf_hub_download( - repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir - ) - - # Load Model Config from `config.json` - with open(config_json, "r") as f: - model_cfg = json.load(f)["model"] - - # = Load Individual Components necessary for Instantiating a VLM = - # =>> Print Minimal Config - overwatch.info( - f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n" - f" Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n" - f" LLM Backbone =>> [bold]{model_cfg['llm_backbone_id']}[/]\n" - f" Arch Specifier =>> [bold]{model_cfg['arch_specifier']}[/]\n" - f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" - ) - - # Load Vision Backbone - overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]") - vision_backbone, image_transform = get_vision_backbone_and_transform( - model_cfg["vision_backbone_id"], - model_cfg["image_resize_strategy"], - ) - - # Load LLM Backbone --> note `inference_mode = 
True` by default when calling `load()` - overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers") - llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( - model_cfg["llm_backbone_id"], - llm_max_length=model_cfg.get("llm_max_length", 2048), - hf_token=hf_token, - inference_mode=not load_for_training, - ) - - # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile) - overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint") - vlm = PrismaticVLM.from_pretrained( - checkpoint_pt, - model_cfg["model_id"], - vision_backbone, - llm_backbone, - arch_specifier=model_cfg["arch_specifier"], - freeze_weights=not load_for_training, - ) - - return vlm - - -# === Load Pretrained VLA Model === -def load_vla( - model_id_or_path: Union[str, Path], - hf_token: Optional[str] = None, - cache_dir: Optional[Union[str, Path]] = None, - load_for_training: bool = False, - step_to_load: Optional[int] = None, - model_type: str = "pretrained", -) -> OpenVLA: - """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub.""" - - # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to - # checkpoint `.pt` file, rather than the top-level run directory! - if os.path.isfile(model_id_or_path): - overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`") - - # [Validate] Checkpoint Path should look like `...//checkpoints/.pt` - assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!" - run_dir = checkpoint_pt.parents[1] - - # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint - config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json" - assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" - assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" - - # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`) - else: - # Search HF Hub Repo via fsspec API - overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`") - if not (tmpfs := HfFileSystem()).exists(hf_path): - raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`") - - # Identify Checkpoint to Load (via `step_to_load`) - step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None - valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt") - if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1): - raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/") - - # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element - target_ckpt = Path(valid_ckpts[-1]).name - - overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`") - with overwatch.local_zero_first(): - relpath = Path(model_type) / model_id_or_path - config_json = hf_hub_download( - repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir - ) - dataset_statistics_json = hf_hub_download( - repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir - ) - checkpoint_pt = hf_hub_download( - repo_id=VLA_HF_HUB_REPO, 
filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir - ) - - # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json` - with open(config_json, "r") as f: - vla_cfg = json.load(f)["vla"] - model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])() - - # Load Dataset Statistics for Action Denormalization - with open(dataset_statistics_json, "r") as f: - norm_stats = json.load(f) - - # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) = - # =>> Print Minimal Config - overwatch.info( - f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n" - f" Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n" - f" LLM Backbone =>> [bold]{model_cfg.llm_backbone_id}[/]\n" - f" Arch Specifier =>> [bold]{model_cfg.arch_specifier}[/]\n" - f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" - ) - - # Load Vision Backbone - overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]") - vision_backbone, image_transform = get_vision_backbone_and_transform( - model_cfg.vision_backbone_id, - model_cfg.image_resize_strategy, - ) - - # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()` - overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers") - llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( - model_cfg.llm_backbone_id, - llm_max_length=model_cfg.llm_max_length, - hf_token=hf_token, - inference_mode=not load_for_training, - ) - - # Create Action Tokenizer - action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer()) - - # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile) - overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint") - vla = OpenVLA.from_pretrained( - checkpoint_pt, - model_cfg.model_id, - vision_backbone, - llm_backbone, - arch_specifier=model_cfg.arch_specifier, - freeze_weights=not load_for_training, - norm_stats=norm_stats, - action_tokenizer=action_tokenizer, - ) - - return vla diff --git a/capvector-oft/prismatic/models/materialize.py b/capvector-oft/prismatic/models/materialize.py deleted file mode 100644 index 90b1fd4ba4dc15fe1a1da47db5ec970424f7c066..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/materialize.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -materialize.py - -Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports -individual functions for clear control flow. 
-""" - -from typing import Optional, Tuple - -from transformers import PreTrainedTokenizerBase - -from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone -from prismatic.models.backbones.vision import ( - CLIPViTBackbone, - DinoCLIPViTBackbone, - DinoSigLIPViTBackbone, - DinoV2ViTBackbone, - ImageTransform, - IN1KViTBackbone, - SigLIPViTBackbone, - VisionBackbone, -) -from prismatic.models.vlms import PrismaticVLM - -# === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === -# fmt: off - -# === Vision Backbone Registry === -VISION_BACKBONES = { - # === 224px Backbones === - "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, - "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, - "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, - "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, - "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, - - # === Assorted CLIP Backbones === - "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, - "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, - - # === Assorted SigLIP Backbones === - "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, - "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, - "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, - "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, - - # === Fused Backbones === - "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, - "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, -} - - -# === Language Model Registry === -LLM_BACKBONES = { - # === LLaMa-2 Pure (Non-Chat) Backbones === - "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, - "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, - - # === LLaMa-2 Chat Backbones === - "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, - "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, - - # === Vicuna-v1.5 Backbones === - "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, - "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, - - # === Mistral v0.1 Backbones === - "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}}, - "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}}, - - # === Phi-2 Backbone === - "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}, -} - -# fmt: on - - -def get_vision_backbone_and_transform( - vision_backbone_id: str, image_resize_strategy: str -) -> Tuple[VisionBackbone, ImageTransform]: - """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" - if vision_backbone_id in VISION_BACKBONES: - vision_cfg = VISION_BACKBONES[vision_backbone_id] - vision_backbone: VisionBackbone = vision_cfg["cls"]( - vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"] - ) - image_transform = vision_backbone.get_image_transform() - return vision_backbone, image_transform - - else: - raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") - - -def get_llm_backbone_and_tokenizer( - llm_backbone_id: str, - 
llm_max_length: int = 2048, - hf_token: Optional[str] = None, - inference_mode: bool = False, -) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: - if llm_backbone_id in LLM_BACKBONES: - llm_cfg = LLM_BACKBONES[llm_backbone_id] - llm_backbone: LLMBackbone = llm_cfg["cls"]( - llm_backbone_id, - llm_max_length=llm_max_length, - hf_token=hf_token, - inference_mode=inference_mode, - **llm_cfg["kwargs"], - ) - tokenizer = llm_backbone.get_tokenizer() - return llm_backbone, tokenizer - - else: - raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") - - -def get_vlm( - model_id: str, - arch_specifier: str, - vision_backbone: VisionBackbone, - llm_backbone: LLMBackbone, - enable_mixed_precision_training: bool = True, -) -> PrismaticVLM: - """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" - return PrismaticVLM( - model_id, - vision_backbone, - llm_backbone, - enable_mixed_precision_training=enable_mixed_precision_training, - arch_specifier=arch_specifier, - ) diff --git a/capvector-oft/prismatic/models/projectors.py b/capvector-oft/prismatic/models/projectors.py deleted file mode 100644 index 80aee7f02198b6a271b122e1427dad3faafb4be0..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/projectors.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Implementation of additional projectors for additional inputs to the VLA models.""" -import torch -import torch.nn as nn - - -class ProprioProjector(nn.Module): - """ - Projects proprio state inputs into the LLM's embedding space. - """ - def __init__(self, llm_dim: int, proprio_dim: int) -> None: - super().__init__() - self.llm_dim = llm_dim - self.proprio_dim = proprio_dim - - self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True) - self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) - self.act_fn1 = nn.GELU() - - def forward(self, proprio: torch.Tensor = None) -> torch.Tensor: - # proprio: (bsz, proprio_dim) - projected_features = self.fc1(proprio) - projected_features = self.act_fn1(projected_features) - projected_features = self.fc2(projected_features) - return projected_features - - -class NoisyActionProjector(nn.Module): - """ - [Diffusion] Projects noisy action inputs into the LLM's embedding space. - - Note that since each action is tokenized into 7 tokens in OpenVLA (rather - than having 1 token per action), each noisy action token will have dimension 1 - instead of 7. - """ - def __init__(self, llm_dim: int) -> None: - super().__init__() - self.llm_dim = llm_dim - self.action_token_dim = 1 - - self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True) - self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) - self.act_fn1 = nn.GELU() - - def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor: - # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1) - projected_features = self.fc1(noisy_actions) - projected_features = self.act_fn1(projected_features) - projected_features = self.fc2(projected_features) - return projected_features diff --git a/capvector-oft/prismatic/models/registry.py b/capvector-oft/prismatic/models/registry.py deleted file mode 100644 index cde181f1f50998e166fcbc376407c87777ec642b..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/registry.py +++ /dev/null @@ -1,691 +0,0 @@ -""" -registry.py - -Exhaustive list of pretrained VLMs (with full descriptions / links to corresponding names and sections of paper). 
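# --- Illustrative note (hypothetical lookups; not part of the registry itself) ---
# `GLOBAL_REGISTRY` (built at the bottom of this file) indexes each entry both by its canonical
# model ID and by every human-readable name in its `names` list, so the following two lookups
# resolve to the same metadata dict:
#
#   GLOBAL_REGISTRY["prism-dinosiglip+7b"]["description"]["language_model"]   # by model ID
#   GLOBAL_REGISTRY["Prism-DINOSigLIP 7B"]["description"]["language_model"]   # by display name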
-""" - -# === Pretrained Model Registry === -# fmt: off -MODEL_REGISTRY = { - # === LLaVa v1.5 Reproductions === - "reproduction-llava-v15+7b": { - "model_id": "reproduction-llava-v15+7b", - "names": ["LLaVa v1.5 7B (Reproduction)"], - "description": { - "name": "LLaVa v1.5 7B (Reproduction)", - "optimization_procedure": "multi-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "reproduction-llava-v15+13b": { - "model_id": "reproduction-llava-v15+13b", - "names": ["LLaVa v1.5 13B (Reproduction)"], - "description": { - "name": "LLaVa v1.5 13B (Reproduction)", - "optimization_procedure": "multi-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - - # === Section 4.1 :: Optimization Procedure === - "one-stage+7b": { - "model_id": "one-stage+7b", - "names": [ - "One-Stage 7B", - "Single-Stage 7B", - "Frozen ViT (Single-Stage)", - "CLIP ViT-L 336px (Letterbox)", - "CLIP ViT-L 336px", - "Vicuña v1.5 7B", - "1 Epoch", - "Base", - ], - "description": { - "name": "Single-Stage 7B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "one-stage+13b": { - "model_id": "one-stage+13b", - "names": [ - "One-Stage 13B", - "Single-Stage 13B", - "Vicuña v1.5 13B", - ], - "description": { - "name": "Single-Stage 13B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - - "full-ft-multi-stage+7b": { - "model_id": "full-ft-multi-stage+7b", - "names": ["Finetune ViT (Multi-Stage)"], - "description": { - "name": "Finetune ViT (Multi-Stage)", - "optimization_procedure": "multi-stage-full-finetune", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "full-ft-one-stage+7b": { - "model_id": "full-ft-one-stage+7b", - "names": ["Finetune ViT (Single-Stage)"], - "description": { - "name": "Finetune ViT (Single-Stage)", - "optimization_procedure": "single-stage-full-finetune", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - - # === Section 4.2 :: Image Processing and Visual Representations === - "in1k-224px+7b": { - "model_id": "in1k-224px+7b", - "names": ["IN1K ViT-L 224px"], - "description": { - "name": "IN1K ViT-L 224px", - "optimization_procedure": "single-stage", - "visual_representation": "ImageNet-21K+1K ViT-L/16 @ 224px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - }, - }, - "dinov2-224px+7b": { - "model_id": "dinov2-224px+7b", - "names": ["DINOv2 ViT-L 224px"], - "description": { - "name": "DINOv2 ViT-L 224px", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 @ 224px", - "image_processing": "Letterbox", - "language_model": 
"Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - }, - }, - "clip-224px+7b": { - "model_id": "clip-224px+7b", - "names": ["CLIP ViT-L 224px"], - "description": { - "name": "CLIP ViT-L 224px", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 224px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - }, - }, - "siglip-224px+7b": { - "model_id": "siglip-224px+7b", - "names": ["SigLIP ViT-SO 224px"], - "description": { - "name": "SigLIP ViT-SO 224px", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 224px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - }, - }, - - "clip-336px-resize-crop+7b": { - "model_id": "clip-336px-resize-crop+7b", - "names": ["CLIP ViT-L 336px (Resize Crop)"], - "description": { - "name": "CLIP ViT-L 336px (Resize Crop)", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Resize Crop", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "clip-336px-resize-naive+7b": { - "model_id": "clip-336px-resize-naive+7b", - "names": ["CLIP ViT-L 336px (Naive Resize)", "CLIP 336px (Naive Resize)"], - "description": { - "name": "CLIP ViT-L 336px (Naive Resize)", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Naive Resize", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "siglip-384px-letterbox+7b": { - "model_id": "siglip-384px-letterbox+7b", - "names": ["SigLIP ViT-SO 384px (Letterbox)", "SigLIP ViT-SO 384px"], - "description": { - "name": "SigLIP ViT-SO 384px (Letterbox)", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "siglip-384px-resize-crop+7b": { - "model_id": "siglip-384px-resize-crop+7b", - "names": ["SigLIP ViT-SO 384px (Resize Crop)"], - "description": { - "name": "SigLIP ViT-SO 384px (Resize Crop)", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Resize Crop", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "siglip-384px-resize-naive+7b": { - "model_id": "siglip-384px-resize-naive+7b", - "names": ["SigLIP ViT-SO 384px (Naive Resize)", "SigLIP 384px (Naive Resize)"], - "description": { - "name": "SigLIP ViT-SO 384px (Naive Resize)", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - - "dinoclip-336px-letterbox+7b": { - "model_id": "dinoclip-336px-letterbox+7b", - "names": ["DINOv2 + CLIP 336px (Letterbox)"], - "description": { - "name": "DINOv2 + CLIP 336px (Letterbox)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - 
"dinoclip-336px-resize-naive+7b": { - "model_id": "dinoclip-336px-resize-naive+7b", - "names": ["DINOv2 + CLIP 336px (Naive Resize)"], - "description": { - "name": "DINOv2 + CLIP 336px (Naive Resize)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px", - "image_processing": "Naive Resize", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "dinosiglip-384px-letterbox+7b": { - "model_id": "dinosiglip-384px-letterbox+7b", - "names": ["DINOv2 + SigLIP 384px (Letterbox)"], - "description": { - "name": "DINOv2 + SigLIP 384px (Letterbox)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-L/14 @ 384px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "dinosiglip-384px-resize-naive+7b": { - "model_id": "dinosiglip-384px-resize-naive+7b", - "names": ["DINOv2 + SigLIP 384px (Naive Resize)"], - "description": { - "name": "DINOv2 + SigLIP 384px (Naive Resize)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-L/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - - # === Section 4.3 :: Language Models === - "llama2+7b": { - "model_id": "llama2+7b", - "names": ["Llama-2 7B"], - "description": { - "name": "Llama-2 7B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - }, - }, - "llama2+13b": { - "model_id": "llama2+13b", - "names": ["Llama-2 13B"], - "description": { - "name": "Llama-2 13B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - }, - }, - - "vicuna-no-cotraining+7b": { - "model_id": "vicuna-no-cotraining+7b", - "names": ["Vicuña v1.5 7B (No Co-training)"], - "description": { - "name": "Vicuña v1.5 7B (No Co-training)", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Multimodal-Only"], - "train_epochs": 1, - }, - }, - "llama2-no-cotraining+7b": { - "model_id": "llama2-no-cotraining+7b", - "names": ["Llama-2 7B (No Co-training)"], - "description": { - "name": "Llama-2 7B (No Co-training)", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Multimodal-Only"], - "train_epochs": 1, - }, - }, - - # === Section 4.4 :: Scaling Properties === - "train-1.25-epochs+7b": { - "model_id": "train-1.25-epochs+7b", - "names": ["1.25 Epochs"], - "description": { - "name": "1.25 Epochs", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1.25, - } - }, - "train-1.5-epochs+7b": { - "model_id": "train-1.5-epochs+7b", - "names": ["1.5 Epochs"], - "description": { - "name": "1.5 Epochs", - 
"optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1.5, - } - }, - "train-2-epochs+7b": { - "model_id": "train-2-epochs+7b", - "names": ["2 Epochs"], - "description": { - "name": "2 Epochs", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 2, - } - }, - "train-3-epochs+7b": { - "model_id": "train-3-epochs+7b", - "names": ["3 Epochs"], - "description": { - "name": "3 Epochs", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 3, - } - }, - - "llava-lvis4v+7b": { - "model_id": "llava-lvis4v+7b", - "names": ["Base + LVIS-4V"], - "description": { - "name": "Base + LVIS-4V", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V"], - "train_epochs": 1, - } - }, - "llava-lrv+7b": { - "model_id": "llava-lrv+7b", - "names": ["Base + LRV"], - "description": { - "name": "Base + LRV", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct", "LRV-Instruct"], - "train_epochs": 1, - } - }, - "llava-lvis4v-lrv+7b": { - "model_id": "llava-lvis4v-lrv+7b", - "names": ["Base + LVIS-4V + LRV"], - "description": { - "name": "Base + LVIS-4V + LRV", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Vicuña v1.5 7B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 1, - } - }, - - # === - - # === CLIP Prism Models === - "prism-clip-controlled+7b": { - "model_id": "prism-clip-controlled+7b", - "names": ["Prism-CLIP 7B (Controlled)"], - "description": { - "name": "CLIP Prism 7B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-clip-controlled+13b": { - "model_id": "prism-clip-controlled+13b", - "names": ["Prism-CLIP 13B (Controlled)"], - "description": { - "name": "CLIP Prism 13B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-clip+7b": { - "model_id": "prism-clip+7b", - "names": ["Prism-CLIP 7B"], - "description": { - "name": "CLIP Prism 7B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - }, - }, - "prism-clip+13b": { - "model_id": "prism-clip+13b", - "names": ["Prism-CLIP 13B"], - "description": { - "name": 
"CLIP Prism 13B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - }, - }, - - # === SigLIP Prism Models == - "prism-siglip-controlled+7b": { - "model_id": "prism-siglip-controlled+7b", - "names": ["Prism-SigLIP 7B (Controlled)"], - "description": { - "name": "SigLIP Prism 7B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-siglip-controlled+13b": { - "model_id": "prism-siglip-controlled+7b", - "names": ["Prism-SigLIP 13B (Controlled)"], - "description": { - "name": "SigLIP Prism 13B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-siglip+7b": { - "model_id": "prism-siglip+7b", - "names": ["Prism-SigLIP 7B"], - "description": { - "name": "SigLIP Prism 7B", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - } - }, - "prism-siglip+13b": { - "model_id": "prism-siglip+13b", - "names": ["Prism-SigLIP 13B"], - "description": { - "name": "SigLIP Prism 13B", - "optimization_procedure": "single-stage", - "visual_representation": "SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - } - }, - - # === DINOSigLIP Prism Models === - "prism-dinosiglip-controlled+7b": { - "model_id": "prism-dinosiglip-controlled+7b", - "names": ["Prism-DINOSigLIP 7B (Controlled)", "Prism 7B (Controlled)"], - "description": { - "name": "DINOSigLIP Prism 7B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-dinosiglip-controlled+13b": { - "model_id": "prism-dinosiglip-controlled+13b", - "names": ["Prism-DINOSigLIP 13B (Controlled)", "Prism 13B (Controlled)"], - "description": { - "name": "DINOSigLIP Prism 13B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-dinosiglip+7b": { - "model_id": "prism-dinosiglip+7b", - "names": ["Prism-DINOSigLIP 7B"], - "description": { - "name": "DINOSigLIP Prism 7B", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - }, - }, - "prism-dinosiglip+13b": { - "model_id": "prism-dinosiglip+13b", - "names": 
["Prism-DINOSigLIP 13B"], - "description": { - "name": "DINOSigLIP Prism 13B", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 13B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - }, - }, - - # === DINOSigLIP 224px Prism Models === - "prism-dinosiglip-224px-controlled+7b": { - "model_id": "prism-dinosiglip-224px-controlled+7b", - "names": ["Prism-DINOSigLIP 224px 7B (Controlled)"], - "description": { - "name": "DINOSigLIP 224px 7B (Controlled)", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO 14 @ 224px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "prism-dinosiglip-224px+7b": { - "model_id": "prism-dinosiglip-224px+7b", - "names": ["Prism-DINOSigLIP 224px 7B"], - "description": { - "name": "DINOSigLIP 224px 7B", - "optimization_procedure": "single-stage", - "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO 14 @ 224px", - "image_processing": "Naive Resize", - "language_model": "Llama-2 7B", - "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], - "train_epochs": 2, - } - }, - - # === Additional LLM Backbones === - "llama2-chat+7b": { - "model_id": "llama2-chat+7b", - "names": ["Llama-2 Chat 7B"], - "description": { - "name": "Llama-2 Chat 7B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Llama-2 Chat 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "llama2-chat+13b": { - "model_id": "llama2-chat+13b", - "names": ["Llama-2 Chat 13B"], - "description": { - "name": "Llama-2 Chat 13B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Llama-2 Chat 13B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "mistral-v0.1+7b": { - "model_id": "mistral-v0.1+7b", - "names": ["Mistral v0.1 7B"], - "description": { - "name": "Mistral v0.1 7B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Mistral v0.1 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "mistral-instruct-v0.1+7b": { - "model_id": "mistral-instruct-v0.1+7b", - "names": ["Mistral Instruct v0.1 7B"], - "description": { - "name": "Mistral Instruct v0.1 7B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Mistral Instruct v0.1 7B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, - "phi-2+3b": { - "model_id": "phi-2+3b", - "names": ["Phi-2 3B"], - "description": { - "name": "Phi-2 3B", - "optimization_procedure": "single-stage", - "visual_representation": "CLIP ViT-L/14 @ 336px", - "image_processing": "Letterbox", - "language_model": "Phi-2 3B", - "datasets": ["LLaVa v1.5 Instruct"], - "train_epochs": 1, - } - }, -} - -# Build Global Registry (Model ID, Name) -> Metadata -GLOBAL_REGISTRY = {name: v for k, v in MODEL_REGISTRY.items() for name in [k] + v["names"]} - -# fmt: on diff --git a/capvector-oft/prismatic/models/vlas/__init__.py 
b/capvector-oft/prismatic/models/vlas/__init__.py deleted file mode 100644 index 0f6889016694a373807c125c4459ecf9b9369811..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/vlas/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .openvla import OpenVLA diff --git a/capvector-oft/prismatic/models/vlas/openvla.py b/capvector-oft/prismatic/models/vlas/openvla.py deleted file mode 100644 index 4aa1e3fe8d69f12401d80e8f18708ce8cdaf47be..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/vlas/openvla.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -openvla.py - -PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around -discretizing actions with the ActionTokenizer. -""" - -from typing import Dict, List, Optional - -import numpy as np -import torch -from PIL import Image -from transformers import LlamaTokenizerFast - -from prismatic.models.vlms.prismatic import PrismaticVLM -from prismatic.overwatch import initialize_overwatch -from prismatic.vla.action_tokenizer import ActionTokenizer - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -class OpenVLA(PrismaticVLM): - def __init__( - self, - *args, - norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], - action_tokenizer: ActionTokenizer, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.norm_stats = norm_stats - self.action_tokenizer = action_tokenizer - - @torch.inference_mode() - def predict_action( - self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str - ) -> np.ndarray: - """ - Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes). - - @param image: PIL Image as [height, width, 3] - @param instruction: Task instruction string - @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model - was trained only on a single dataset, and retrieves those statistics. - - @return Unnormalized (continuous) action vector --> end-effector deltas. 
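# --- Hypothetical usage sketch (image path, instruction, and dataset key are placeholders) ---
#
#   from PIL import Image
#
#   image = Image.open("frame.png")
#   action = vla.predict_action(
#       image,
#       instruction="pick up the cup",
#       unnorm_key="<dataset_name>",   # may be omitted if the model was trained on a single dataset
#   )
#   # `action` is a continuous, un-normalized action vector (e.g., end-effector deltas), recovered by
#   # de-tokenizing the generated action tokens and rescaling with the per-dataset q01/q99 statistics.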
- """ - image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer - - # Build VLA Prompt - prompt_builder = self.get_prompt_builder() - prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?") - prompt_text = prompt_builder.get_prompt() - - # Prepare Inputs - input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) - if isinstance(tokenizer, LlamaTokenizerFast): - # If the special empty token ('') does not already appear after the colon (':') token in the prompt - # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time - if not torch.all(input_ids[:, -1] == 29871): - input_ids = torch.cat( - (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 - ) - else: - raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}") - - # Preprocess Image - pixel_values = image_transform(image) - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values[None, ...].to(self.device) - elif isinstance(pixel_values, dict): - pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} - else: - raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") - - # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` - autocast_dtype = self.llm_backbone.half_precision_dtype - with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): - # fmt: off - generated_ids = super(PrismaticVLM, self).generate( - input_ids=input_ids, # Shape: [1, seq] - pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, ...] - max_new_tokens=self.get_action_dim(unnorm_key), - **kwargs - ) - # fmt: on - - # Extract predicted action tokens and translate into (normalized) continuous actions - predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :] - normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy()) - - # Un-normalize Actions - action_norm_stats = self.get_action_stats(unnorm_key) - mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) - action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) - actions = np.where( - mask, - 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low, - normalized_actions, - ) - - return actions - - @staticmethod - def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str: - if unnorm_key is None: - assert len(norm_stats) == 1, ( - f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following " - f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}" - ) - unnorm_key = next(iter(norm_stats.keys())) - - # Error Handling - assert ( - unnorm_key in norm_stats - ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}" - - return unnorm_key - - def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: - """Dimensionality of the policy's action space.""" - unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) - - return len(self.norm_stats[unnorm_key]["action"]["q01"]) - - def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict: - """Dimensionality of the policy's action space.""" - unnorm_key = self._check_unnorm_key(self.norm_stats, 
unnorm_key) - - return self.norm_stats[unnorm_key]["action"] diff --git a/capvector-oft/prismatic/models/vlms/__init__.py b/capvector-oft/prismatic/models/vlms/__init__.py deleted file mode 100644 index 07b1d606a27c25694d7f2f24ac2bbf686adfe0ab..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/vlms/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .prismatic import PrismaticVLM diff --git a/capvector-oft/prismatic/models/vlms/base_vlm.py b/capvector-oft/prismatic/models/vlms/base_vlm.py deleted file mode 100644 index 24e180470926cccb7f2d059f83197efb17b1a573..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/vlms/base_vlm.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -base_vlm.py - -Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, -and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate -from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, -PALI, Fuyu) in the future. - -We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance -(e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), -prefer Protocol definitions instead. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Callable, List, Optional - -import torch -import torch.nn as nn -from transformers import GenerationMixin, PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast - -from prismatic.models.backbones.llm import LLMBackbone -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.models.backbones.vision import VisionBackbone - - -# === Abstract Base Class for arbitrary Vision-Language Models === -class VLM(nn.Module, GenerationMixin, ABC): - def __init__( - self, - model_family: str, - model_id: str, - vision_backbone: VisionBackbone, - llm_backbone: LLMBackbone, - enable_mixed_precision_training: bool = True, - ) -> None: - super().__init__() - self.model_family, self.model_id = model_family, model_id - self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone - self.enable_mixed_precision_training = enable_mixed_precision_training - - # Instance Attributes for a generic VLM - self.all_module_keys, self.trainable_module_keys = None, None - - # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === - self.generation_config = self.llm_backbone.llm.generation_config - self.main_input_name = "input_ids" - - @property - def device(self) -> torch.device: - """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" - return next(self.parameters()).device - - @classmethod - @abstractmethod - def from_pretrained( - cls, - pretrained_checkpoint: Path, - model_family: str, - model_id: str, - vision_backbone: VisionBackbone, - llm_backbone: LLMBackbone, - **kwargs: str, - ) -> VLM: ... - - @abstractmethod - def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... - - @abstractmethod - def freeze_backbones(self, stage: str) -> None: ... - - @abstractmethod - def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... - - @abstractmethod - def get_fsdp_wrapping_policy(self) -> Callable: ... 
- - @abstractmethod - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - multimodal_indices: Optional[torch.LongTensor] = None, - ) -> CausalLMOutputWithPast: ... - - # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === - @staticmethod - def can_generate() -> bool: - return True - - @property - def config(self) -> PretrainedConfig: - return self.llm_backbone.llm.config - - # => Beam Search Utility - def _reorder_cache(self, past_key_values, beam_idx): - return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) diff --git a/capvector-oft/prismatic/models/vlms/prismatic.py b/capvector-oft/prismatic/models/vlms/prismatic.py deleted file mode 100644 index 07477f2a14b1da650c4da747a24e407e2633e8a6..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/models/vlms/prismatic.py +++ /dev/null @@ -1,621 +0,0 @@ -""" -prismatic.py - -PyTorch Module defining a PrismaticVLM, our general interface for defining the various different VLMs in our work. - -Notes: - - For now, we don't subclass `transformers.PretrainedModel` (or CausalLM). Instead, we assume a very limited subset - of the {Model}ForCausalLM API that enables dispatch to the underlying LLM's `generate` utilities (feeding inputs - through our custom projection shim). -""" - -from __future__ import annotations - -from functools import partial -from pathlib import Path -from typing import Callable, Dict, List, Optional, Type, Union - -import torch -from PIL import Image -from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy -from transformers.modeling_outputs import CausalLMOutputWithPast - -from prismatic.models.backbones.llm import LLMBackbone -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.models.backbones.vision import VisionBackbone -from prismatic.models.vlms.base_vlm import VLM -from prismatic.overwatch import initialize_overwatch -from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) -IGNORE_INDEX = -100 - - -class PrismaticVLM(VLM): - def __init__( - self, - model_id: str, - vision_backbone: VisionBackbone, - llm_backbone: LLMBackbone, - enable_mixed_precision_training: bool = True, - arch_specifier: str = "gelu-mlp", - **kwargs, - ) -> None: - super().__init__( - "prismatic", - model_id, - vision_backbone, - llm_backbone, - enable_mixed_precision_training=enable_mixed_precision_training, - ) - - # Set Weight Initialization Seed for Projector Consistency - torch.manual_seed(vision_backbone.embed_dim) - - # Initialize Projection (Adapter) based on `arch_specifier` - self.arch_specifier = arch_specifier - if arch_specifier == "linear": - self.projector = LinearProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) - elif arch_specifier.endswith("fused-gelu-mlp"): - self.projector = FusedMLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) - elif arch_specifier.endswith("gelu-mlp"): - self.projector 
= MLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) - else: - raise ValueError(f"PrismaticVLM with `{arch_specifier = }` is not supported!") - - # Trackers - self.vision_backbone_requires_grad = False - - # Set Module Keys =>> used in Checkpoint Saving / Model Loading - self.all_module_keys = ["vision_backbone", "llm_backbone", "projector"] - self.trainable_module_keys = [] - - # === Generation Utilities === - # => For computing likelihoods --> get tokens corresponding to "True", "False" and "Yes", "No" - self.string2idx = {} - for trigger_string in ["True", "False", "Yes", "No"] + [chr(ord("A") + i) for i in range(26)]: - token_idx_list = self.llm_backbone.tokenizer.encode(trigger_string, add_special_tokens=False) - assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!' - self.string2idx[trigger_string] = token_idx_list[0] - - @classmethod - def from_pretrained( - cls, - pretrained_checkpoint: Path, - model_id: str, - vision_backbone: VisionBackbone, - llm_backbone: LLMBackbone, - enable_mixed_precision_training: bool = True, - arch_specifier: str = "gelu-mlp", - freeze_weights: bool = True, - **kwargs, - ) -> PrismaticVLM: - """Initialize a PrismaticVLM from a pretrained checkpoint, freezing all weights, tailored for inference.""" - vlm = cls( - model_id, - vision_backbone, - llm_backbone, - enable_mixed_precision_training=enable_mixed_precision_training, - arch_specifier=arch_specifier, - **kwargs, - ) - - # Load from Checkpoint (Custom --> should load both *projector* and *llm* weights) - model_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")["model"] - assert ( - "projector" in model_state_dict and "llm_backbone" in model_state_dict - ), "PrismaticVLM `from_pretrained` expects checkpoint with keys for `projector` AND `llm_backbone`!" - - vlm.projector.load_state_dict(model_state_dict["projector"]) - vlm.llm_backbone.load_state_dict(model_state_dict["llm_backbone"]) - if "vision_backbone" in model_state_dict.keys(): - vlm.vision_backbone.load_state_dict(model_state_dict["vision_backbone"]) - - # Freeze Weights - if freeze_weights: - vlm.requires_grad_(False) - vlm.eval() - - return vlm - - def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: - prompt_initializer: Type[PromptBuilder] = self.llm_backbone.prompt_builder_fn - return prompt_initializer(self.model_family, system_prompt=system_prompt) - - def freeze_backbones(self, stage: str) -> None: - """ - This function sets `requires_grad_` on each of the component modules explicitly, depending on stage. - - We support two separate stages --> "align" and "finetune". - => "align" --> vision_backbone*, llm_backbone* are frozen; only the `projector` is trained. - => "finetune" --> vision_backbone* is frozen; both `projector` and `llm_backbone` are trained. 
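        As a hypothetical illustration (not a contract of this method): a downstream adaptation run
        might call `vlm.freeze_backbones("vla-train")` before constructing the optimizer, after which
        only `projector` and `llm_backbone` parameters report `requires_grad == True` and are handed
        to the optimizer.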
- - :param stage: Pretraining stage in < "align" | "finetune" | "full-finetune" | "vla-train" | "vla-full-train" > - """ - if stage == "align": - self.vision_backbone.requires_grad_(False) - self.llm_backbone.requires_grad_(False) - self.projector.requires_grad_(True) - - # Add to `self.trainable_module_keys` - self.trainable_module_keys = ["projector"] - - # Update Trackers - self.vision_backbone_requires_grad = False - - # Explicitly Log Frozen / Trainable Components - overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) - overwatch.info(f"[Frozen] 🥶 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) - overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) - - elif stage in {"finetune", "vla-train"}: - self.vision_backbone.requires_grad_(False) - self.llm_backbone.requires_grad_(True) - self.projector.requires_grad_(True) - - # Add to `self.trainable_module_keys` - self.trainable_module_keys = ["projector", "llm_backbone"] - - # Update Trackers - self.vision_backbone_requires_grad = False - - # Explicitly Log Frozen / Unfrozen Components - overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) - overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) - overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) - - elif stage in {"full-finetune", "vla-full-train"}: - self.vision_backbone.dtype = torch.float32 - self.vision_backbone.requires_grad_(True) - self.llm_backbone.requires_grad_(True) - self.projector.requires_grad_(True) - - # Add to `self.trainable_module_keys` - self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"] - - # Update Trackers - self.vision_backbone_requires_grad = True - - # Explicitly Log Frozen / Unfrozen Components - overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) - overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) - overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) - - elif stage in {"last-layer-finetune", "vla-last-layer-train"}: - self.vision_backbone.requires_grad_(False) - self.projector.requires_grad_(False) - self.llm_backbone.requires_grad_(False) - - # Unfreeze final LLM layer - for module in self.llm_backbone.last_layer_finetune_modules: - module.requires_grad_(True) - - # Add to `self.trainable_module_keys` - self.trainable_module_keys = ["llm_backbone"] - - # Update Trackers - self.vision_backbone_requires_grad = False - - # Explicitly Log Frozen / Unfrozen Components - # fmt: off - overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501 - overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501 - overwatch.info(f"[Frozen] 🥶 =>> Projector `{self.arch_specifier}`", ctx_level=1) - # fmt: on - - elif stage in {"vla-sandwich-train"}: - self.vision_backbone.dtype = torch.float32 - self.vision_backbone.requires_grad_(True) - self.projector.requires_grad_(True) - self.llm_backbone.requires_grad_(False) - - # Unfreeze final LLM layer - for module in self.llm_backbone.last_layer_finetune_modules: - module.requires_grad_(True) - - # Add to `self.trainable_module_keys` - self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"] - - # Update Trackers - 
self.vision_backbone_requires_grad = True - - # Explicitly Log Frozen / Unfrozen Components - # fmt: off - overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501 - overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501 - overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) - # fmt: on - - else: - raise ValueError(f"Stage `{stage}` is not supported for LLaVa! Try < align | finetune >") - - overwatch.debug("##################################################") - overwatch.debug("##### Trainable Network Parameters: #####") - overwatch.debug("##################################################") - for name, param in self.named_parameters(): - if param.requires_grad: - overwatch.debug(name) - - def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: - """Load weights from checkpoint (if required by the given stage).""" - assert stage in {"align", "finetune", "full-finetune"}, f"Stage {stage} is not supported!" - - # If we're running a `no-align` architecture, we're good! - if self.arch_specifier.startswith("no-align"): - overwatch.info( - f"PrismaticVLM with `{self.arch_specifier = }` does not require pretrained weights!", ctx_level=1 - ) - return - - # Otherwise, handle stage-specific logic! - if stage == "align": - overwatch.info("Stage `align` does not require pretrained weights =>> Starting Training", ctx_level=1) - return - - # Otherwise, load from `pretrained_checkpoint` or match on `run_dir` (s/+stage-finetune/+stage-align/g) - overwatch.info("Stage `finetune` requires `align` pretrained weights", ctx_level=1) - - # Config specifies path to a checkpoint to load - if pretrained_checkpoint is not None: - overwatch.info(f"Loading from Provided Checkpoint `{pretrained_checkpoint}`", ctx_level=1) - model_state_dict = torch.load(pretrained_checkpoint)["model"] - self.projector.load_state_dict(model_state_dict["projector"]) - - return - - # [Contract] If no `pretrained_checkpoint`, assume `align` lives in the run directory; string substitution! - model, scale, _, seed = run_dir.name.split("+") - align_dirs = [ - d - for d in run_dir.parent.iterdir() - if (d.name.startswith(f"{model}+{scale}") and d.name.endswith(f"+stage-align+{seed}")) - ] - assert len(align_dirs) == 1, "Multiple or No Valid Pretrained Directories Exist -- Double Check `runs`!" 
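        # Illustrative example of this naming contract (run names here are hypothetical):
        #   run_dir.name = "prism-clip+7b+stage-finetune+x7"  -->  model="prism-clip", scale="7b", seed="x7"
        #   which should match exactly one sibling directory "prism-clip+7b+stage-align+x7", whose
        #   `checkpoints/latest-checkpoint.pt` is then loaded below.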
- if (pretrained_checkpoint := (align_dirs[0] / "checkpoints" / "latest-checkpoint.pt")).exists(): - overwatch.info(f"Loading from Discovered Checkpoint `{pretrained_checkpoint}`", ctx_level=1) - model_state_dict = torch.load(pretrained_checkpoint)["model"] - self.projector.load_state_dict(model_state_dict["projector"]) - else: - raise ValueError(f"Could not find valid `align` checkpoint at {pretrained_checkpoint}!") - - def get_fsdp_wrapping_policy(self) -> Callable: - """Return an FSDP _or_policy over the policies returned by each individual backbone (and our VLM policy).""" - vision_fsdp_wrapping_policy = self.vision_backbone.get_fsdp_wrapping_policy() - llm_fsdp_wrapping_policy = self.llm_backbone.get_fsdp_wrapping_policy() - - # Get Prismatic Wrapping Policy =>> just a module wrapping policy around `self.projector` - prismatic_fsdp_wrapping_policy = partial( - _module_wrap_policy, - module_classes={LinearProjector, MLPProjector, FusedMLPProjector}, - ) - - # Return union (_or_) over constituent policies - # => Note: there is *not* a fall-through policy; any module that isn't covered by the above constituents will - # automatically be folded into the root VLM FSDP instance. - return partial( - _or_policy, - policies=[ - vision_fsdp_wrapping_policy, - llm_fsdp_wrapping_policy, - prismatic_fsdp_wrapping_policy, - ], - ) - - # Note =>> We're not explicitly subclassing `PreTrainedModel` because we don't need the bloat; however, `forward()` - # *must* match the signature of a `{Model}ForCausalLM` so that we can inherit from `GenerationMixin` - - # ruff: noqa: C901 - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - multimodal_indices: Optional[torch.LongTensor] = None, - ) -> CausalLMOutputWithPast: - """Run a forward pass through the VLM, returning a CausalLMOutputWithPast instance (contains loss).""" - - # Handle Inference (leverage cache, short-circuit on just LLM forward) - if input_ids.shape[1] == 1 and past_key_values is not None: - # We're leveraging the cache, so just redirect to `self.llm_backbone` with `input_ids` and `past_key_values` - output = self.llm_backbone( - input_ids=input_ids, - attention_mask=None, - position_ids=None, - past_key_values=past_key_values, - inputs_embeds=None, - labels=None, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - return output - - elif input_ids.shape[1] == 1 or pixel_values is None: - raise RuntimeError("Invalid `forward()` call!") - - # Handle Multimodal Indices is None --> pretend like the batch is fully multimodal (always image + text)! 
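        # e.g. (hypothetical batch of 4): if the caller passes nothing, the fallback below is
        #   multimodal_indices = torch.arange(4) = [0, 1, 2, 3] (every example treated as image + text);
        # a mixed batch where only examples 0 and 2 carry images would instead pass
        #   multimodal_indices = torch.tensor([0, 2]), sending rows 1 and 3 down the unimodal path further below.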
- if multimodal_indices is None: - multimodal_indices = torch.arange(len(input_ids), dtype=torch.long, device=input_ids.device) - - # Handle Multimodal Indices is Empty (len == 0) --> simple unimodal forward - elif len(multimodal_indices) == 0: - return self.llm_backbone( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=None, - past_key_values=past_key_values, - inputs_embeds=None, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Run Visual Feature Extraction - with torch.set_grad_enabled(self.vision_backbone_requires_grad): - if isinstance(pixel_values, dict): - patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values}) - else: - patch_features = self.vision_backbone(pixel_values[multimodal_indices]) - - # Projection Logic :: [bsz, num_patches, llm_embed_dim] =>> num_patches = (2 *) (256 + 1) for ViT-L + CLS - projected_patch_embeddings = self.projector(patch_features) - projected_patch_attention_mask = None - if attention_mask is not None: - projected_patch_attention_mask = torch.full( - (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), - True, - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Get Input Embeddings from LLM Backbone :: [bsz, input_seq_len, llm_embed_dim] - input_embeddings = self.llm_backbone.embed_input_ids(input_ids) - - # Build Multimodal Embeddings (and build resulting attention mask) - multimodal_embeddings = torch.cat( - [ - input_embeddings[multimodal_indices, :1, :], - projected_patch_embeddings, - input_embeddings[multimodal_indices, 1:, :], - ], - dim=1, - ) - multimodal_attention_mask = None - if attention_mask is not None: - multimodal_attention_mask = torch.cat( - [ - attention_mask[multimodal_indices, :1], - projected_patch_attention_mask, - attention_mask[multimodal_indices, 1:], - ], - dim=1, - ) - - # [Contract] We assume the first token of `labels` (associated with ) is already marked as "IGNORE" - # => We'll ignore the per-token outputs for each of the patch embeddings as well! - multimodal_labels = None - if labels is not None: - projected_patch_labels = torch.full( - (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), - IGNORE_INDEX, - dtype=labels.dtype, - device=labels.device, - ) - multimodal_labels = torch.cat( - [labels[multimodal_indices, :1], projected_patch_labels, labels[multimodal_indices, 1:]], dim=1 - ) - - # === Add Unimodal Handling === - - # Create Fused Embeddings, Attention Mask, and Labels by Merging with "unimodal" Inputs (if applicable) - unimodal_indices = torch.tensor( - [idx for idx in range(len(input_ids)) if idx not in multimodal_indices], - dtype=torch.long, - device=multimodal_indices.device, - ) - - # No "unimodal" data --> Fused == Multimodal - if len(unimodal_indices) == 0: - fused_embeddings = multimodal_embeddings - fused_attention_mask = multimodal_attention_mask - fused_labels = multimodal_labels - - else: - # Otherwise --> Merge w/ unimodal data - - # This doesn't matter --> but in the "normal" case this is the embedding of the token - # => NOTE :: Verified that `zeros/randn/empty/ embedding` all return the same result! 
- unimodal_embeddings_pad = torch.zeros( - (len(unimodal_indices), projected_patch_embeddings.shape[1], input_embeddings.shape[2]), - dtype=input_embeddings.dtype, - device=input_embeddings.device, - ) - unimodal_attention_pad = torch.full( - (len(unimodal_indices), projected_patch_embeddings.shape[1]), - False, - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - unimodal_labels_pad = torch.full( - (len(unimodal_indices), projected_patch_embeddings.shape[1]), - IGNORE_INDEX, - dtype=labels.dtype, - device=labels.device, - ) - - unimodal_embeddings = torch.cat([input_embeddings[unimodal_indices], unimodal_embeddings_pad], dim=1) - unimodal_attention_mask = torch.cat([attention_mask[unimodal_indices], unimodal_attention_pad], dim=1) - unimodal_labels = torch.cat([labels[unimodal_indices], unimodal_labels_pad], dim=1) - - # Create "Fused" Tensors by Stacking Multimodal & Unimodal - fused_embeddings = torch.vstack([multimodal_embeddings, unimodal_embeddings]) - fused_attention_mask = torch.vstack([multimodal_attention_mask, unimodal_attention_mask]) - fused_labels = torch.vstack([multimodal_labels, unimodal_labels]) - - # Run LLM Forward --> returns CausalLMOutputWithPast! - return self.llm_backbone( - input_ids=None, - attention_mask=fused_attention_mask, - position_ids=None, - past_key_values=past_key_values, - inputs_embeds=fused_embeddings, - labels=fused_labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # === GenerationMixin Methods === - # => Note: The following methods override the functionality of `transformers.GenerationMixin`; these expect the - # contract in each of the function signatures, and also expect our `forward` function to roughly take - # the same arguments as the underlying LLM (see `LlamaModelForCausalLM` as an example) - - def prepare_inputs_for_generation( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - **kwargs: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - """Borrowed from `LlamaForCausalLM` --> in general, just handles caching logic during generation.""" - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - # Make sure `pixel_values` are preserved in `model_inputs` - model_inputs.update( - { - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "past_key_values": past_key_values, - "use_cache": use_cache, - } - ) - - return model_inputs - - @torch.inference_mode() - def generate_batch( - self, - pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]], - texts: List[str], - return_string_probabilities: Optional[List[str]] = None, - **kwargs: str, - ) -> Union[List[str], List[List[float]]]: - # For now, only support generation with a batch size of 1 for simplicity - tokenizer = self.llm_backbone.tokenizer - - # Prepare Inputs - batch_input_ids = [ - tokenizer(text, truncation=True, return_tensors="pt").input_ids.to(self.device) for text in texts - ] - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values[None, 
...].to(self.device) - elif isinstance(pixel_values, dict): - pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} - else: - raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") - - # Create Output Lists - gen_texts, gen_probabilities = [], [] - - # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` - autocast_dtype = self.llm_backbone.half_precision_dtype - with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): - for idx, input_ids in enumerate(batch_input_ids): - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values[idx] - elif isinstance(pixel_values, dict): - pixel_values = {k: pixel_values[k][idx] for k in pixel_values} - else: - raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") - - # Handle `return_string_probabilities` - if return_string_probabilities is None: - full_out_ids = super().generate(input_ids=input_ids, pixel_values=pixel_values, **kwargs) - gen_ids = full_out_ids[0, input_ids.shape[1] :] - - # Decode `gen_ids` and strip any tokens - gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip()) - - else: - full_out_dict = super().generate( - input_ids=input_ids, - pixel_values=pixel_values, - output_scores=True, - return_dict_in_generate=True, - **kwargs, - ) - - # Generation pattern should usually be [TOKEN] for True/False and Yes/No Generations - gen_ids = full_out_dict.sequences[0, input_ids.shape[1] :] - - # [Debug] Verify that the first token generated is in `self.string2idx.values()` - # assert gen_ids[0] in self.string2idx.values(), "Generated ID not in mapping!" - - # Decode `gen_ids` and strip any tokens - gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip()) - - # Get all token probabilities --> softmax over logits - token_probs = torch.softmax(full_out_dict.scores[0][0], dim=0) - - # Get *normalized* probabilities for all values in `return_token_probabilities` - slice_idxs = torch.tensor([self.string2idx[s] for s in return_string_probabilities]) - string_probs_unnormalized = token_probs[slice_idxs] - string_probs = string_probs_unnormalized / string_probs_unnormalized.sum() - gen_probabilities.append(string_probs.cpu().numpy().tolist()) - - return gen_texts if return_string_probabilities is None else gen_probabilities - - @torch.inference_mode() - def generate(self, image: Image, prompt_text: str, **kwargs: str) -> str: - # For now, only support generation with a batch size of 1 for simplicity - image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer - - # Prepare Inputs - input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) - pixel_values = image_transform(image) - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values[None, ...].to(self.device) - elif isinstance(pixel_values, dict): - pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} - else: - raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") - - # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` - autocast_dtype = self.llm_backbone.half_precision_dtype - with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): - # fmt: off - generated_ids = super().generate( - input_ids=input_ids, # Shape: [1, seq] - pixel_values=pixel_values, # Shape: [1, 3, res, res] or 
Dict[str, Shape[1, 3, res, res]] - **kwargs - ) - # fmt: on - - generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip() - - return generated_text diff --git a/capvector-oft/prismatic/overwatch/__init__.py b/capvector-oft/prismatic/overwatch/__init__.py deleted file mode 100644 index 157cd648c6b711bc24f59ea2b356b1c0816a1c11..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/overwatch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .overwatch import initialize_overwatch diff --git a/capvector-oft/prismatic/overwatch/overwatch.py b/capvector-oft/prismatic/overwatch/overwatch.py deleted file mode 100644 index 2e72048bddbf9a4b622b736ba7d25b6c70ac9a04..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/overwatch/overwatch.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -overwatch.py - -Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. -""" - -import logging -import logging.config -import os -from contextlib import nullcontext -from logging import LoggerAdapter -from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union - -# Overwatch Default Format String -RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" - -# Set Logging Configuration -LOG_CONFIG = { - "version": 1, - "disable_existing_loggers": True, - "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, - "handlers": { - "console": { - "class": "rich.logging.RichHandler", - "formatter": "simple-console", - "markup": True, - "rich_tracebacks": True, - "show_level": True, - "show_path": True, - "show_time": True, - } - }, - "root": {"level": "INFO", "handlers": ["console"]}, -} -logging.config.dictConfig(LOG_CONFIG) - - -# === Custom Contextual Logging Logic === -class ContextAdapter(LoggerAdapter): - CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} - - def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: - ctx_level = kwargs.pop("ctx_level", 0) - return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs - - -class DistributedOverwatch: - def __init__(self, name: str) -> None: - """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" - from accelerate import PartialState - - # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` - # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! - self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() - - # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) - self.debug = self.logger.debug - self.info = self.logger.info - self.warning = self.logger.warning - self.error = self.logger.error - self.critical = self.logger.critical - - # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
- self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) - - @property - def rank_zero_only(self) -> Callable[..., Any]: - return self.distributed_state.on_main_process - - @property - def local_zero_only(self) -> Callable[..., Any]: - return self.distributed_state.on_local_main_process - - @property - def rank_zero_first(self) -> Callable[..., Any]: - return self.distributed_state.main_process_first - - @property - def local_zero_first(self) -> Callable[..., Any]: - return self.distributed_state.local_main_process_first - - def is_rank_zero(self) -> bool: - return self.distributed_state.is_main_process - - def rank(self) -> int: - return self.distributed_state.process_index - - def local_rank(self) -> int: - return self.distributed_state.local_process_index - - def world_size(self) -> int: - return self.distributed_state.num_processes - - -class PureOverwatch: - def __init__(self, name: str) -> None: - """Initializer for an Overwatch object that just wraps logging.""" - self.logger = ContextAdapter(logging.getLogger(name), extra={}) - - # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) - self.debug = self.logger.debug - self.info = self.logger.info - self.warning = self.logger.warning - self.error = self.logger.error - self.critical = self.logger.critical - - # Logging Defaults =>> INFO - self.logger.setLevel(logging.INFO) - - @staticmethod - def get_identity_ctx() -> Callable[..., Any]: - def identity(fn: Callable[..., Any]) -> Callable[..., Any]: - return fn - - return identity - - @property - def rank_zero_only(self) -> Callable[..., Any]: - return self.get_identity_ctx() - - @property - def local_zero_only(self) -> Callable[..., Any]: - return self.get_identity_ctx() - - @property - def rank_zero_first(self) -> Callable[..., Any]: - return nullcontext - - @property - def local_zero_first(self) -> Callable[..., Any]: - return nullcontext - - @staticmethod - def is_rank_zero() -> bool: - return True - - @staticmethod - def rank() -> int: - return 0 - - @staticmethod - def world_size() -> int: - return 1 - - -def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: - return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) diff --git a/capvector-oft/prismatic/preprocessing/__init__.py b/capvector-oft/prismatic/preprocessing/__init__.py deleted file mode 100644 index 5b3a1dcb91afb745463f05da3fe7a547f030a4e7..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/preprocessing/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .download import convert_to_jpg, download_extract -from .materialize import get_dataset_and_collator diff --git a/capvector-oft/prismatic/preprocessing/datasets/__init__.py b/capvector-oft/prismatic/preprocessing/datasets/__init__.py deleted file mode 100644 index 04f5fe1a8dd8122308f1a5f707b48e7cc87a2311..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/preprocessing/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .datasets import AlignDataset, FinetuneDataset diff --git a/capvector-oft/prismatic/preprocessing/datasets/datasets.py b/capvector-oft/prismatic/preprocessing/datasets/datasets.py deleted file mode 100644 index be86002805411363db96840878e039b86923f109..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/preprocessing/datasets/datasets.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -datasets.py - -PyTorch Dataset Definitions for Prismatic 
models; supports processing for both the `align` and `finetune` stages, with -utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected -formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models). - -We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that -random access image reading is relatively cheap/fast. -""" - -import copy -import json -from pathlib import Path -from typing import Dict, List, Tuple, Type - -import torch -from PIL import Image -from torch.utils.data import Dataset -from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase - -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.models.backbones.vision import ImageTransform - -# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) -IGNORE_INDEX = -100 - - -class AlignDataset(Dataset[Dict[str, torch.Tensor]]): - def __init__( - self, - chat_json: Path, - image_dir: Path, - image_transform: ImageTransform, - tokenizer: PreTrainedTokenizerBase, - ) -> None: - super().__init__() - self.chat_json, self.image_dir = chat_json, image_dir - self.image_transform, self.tokenizer = image_transform, tokenizer - self.dataset_type = "align" - - # Create Prompt Template - self.prompt_template = "{caption}" + self.tokenizer.eos_token - - # Load Chat JSON - with open(self.chat_json, "r") as f: - self.examples = json.load(f) - - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - """ - Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard - the "prompt" from the human, and instead directly predict the caption from the image. - - As a concrete example given the "raw data" for the first example: - example = self.examples[0]["conversations"]` = { - [ - {"from": "human", "value": "Render a clear and concise summary of the photo.\n"}, - {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"} - ] - } - - Return =>> self.tokenizer(" select luxury furniture 3 - inch gel memory foam mattress topper\n") - - :param idx: Index to retrieve from the dataset. - - :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor} - """ - image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"] - assert (len(conversation) == 2) and ("" not in conversation[-1]["value"]), "Unexpected text!" - - # Format Caption --> {caption}{eos_token} - caption = self.prompt_template.format(caption=conversation[-1]["value"].strip()) - - # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens. - # => Critically, we find that inserting *after* the BOS token leads to the strongest performance! - # - input_ids = " p1 p2 p3 ... \n" - # - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing and p{1...K} with IGNORE) - # - # IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! 
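        # Toy illustration of the construction performed below (made-up token ids, not from a real tokenizer):
        _toy_input_ids = torch.tensor([1, 3290, 731, 88, 2])  # e.g., [BOS, "select", "luxury", ..., EOS]
        _toy_labels = _toy_input_ids.clone()
        _toy_labels[0] = IGNORE_INDEX  # BOS position is ignored, since the image patches are inserted right after it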
- input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0] - labels = copy.deepcopy(input_ids) - - # Set the token's label to IGNORE_INDEX (since we're inserting the image patches right after) - labels[0] = IGNORE_INDEX - - # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor]) - pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) - - return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) - - def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]: - """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example.""" - modality_lengths = [] - for example in self.examples: - is_multimodal = "image" in example - n_words = sum([len(turn["value"].replace("", "").split()) for turn in example["conversations"]]) - modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words)) - return modality_lengths - - def __len__(self) -> int: - return len(self.examples) - - -class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]): - def __init__( - self, - instruct_json: Path, - image_dir: Path, - image_transform: ImageTransform, - tokenizer: PreTrainedTokenizerBase, - prompt_builder_fn: Type[PromptBuilder], - ) -> None: - super().__init__() - self.instruct_json, self.image_dir = instruct_json, image_dir - self.image_transform, self.tokenizer = image_transform, tokenizer - self.prompt_builder_fn = prompt_builder_fn - self.dataset_type = "finetune" - - # Load Instruct JSON - with open(self.instruct_json, "r") as f: - self.examples = json.load(f) - - # === Unimodal + Multimodal Handling === - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - """ - Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of - dialog grounded in a single image. - - To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the - methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example. - - :param idx: Index to retrieve from the dataset. - - :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor} - """ - conversation = self.examples[idx]["conversations"] - - # Create Prompt Builder --> add each message sequentially - prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], [] - for turn_idx, turn in enumerate(conversation): - # Get "effective" string added to prompt --> handle whitespace for tokenizer type! - msg = prompt_builder.add_turn(turn["from"], turn["value"]) - - # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty! - if isinstance(self.tokenizer, LlamaTokenizerFast): - msg = msg.rstrip() - - # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling! - elif isinstance(self.tokenizer, CodeGenTokenizerFast): - pass - - else: - raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!") - - # Tokenize Input IDs - turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids - - # [CRITICAL] We do not want to take the loss for the "USER: " prompts =>> just the responses! 
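            # e.g. (hypothetical turn lengths): a 7-token user turn (turn_idx == 0) contributes
            # [IGNORE_INDEX] * 7 to `labels`, while the following 5-token assistant turn (turn_idx == 1)
            # contributes its 5 token ids unchanged -- so the loss is computed only on the responses.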
- turn_labels = ( - [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids) - ) - - # Add to Trackers - input_ids.extend(turn_input_ids) - labels.extend(turn_labels) - - # Tensorize =>> Set the token's label to IGNORE_INDEX (since we're inserting the image patches after) - # - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! - input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) - - # Handle Truncation (if necessary) - input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length] - - # === Handle "unimodal" (language-only) vs. "multimodal" === - if "image" in self.examples[idx]: - image_path = Path(self.examples[idx]["image"]) - - # Set the token's label to IGNORE_INDEX (since we're inserting the image patches right after) - labels[0] = IGNORE_INDEX - - # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor]) - pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) - - return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) - - else: - # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us! - return dict(pixel_values=None, input_ids=input_ids, labels=labels) - - def get_modality_lengths(self) -> List[Tuple[bool, int]]: - """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example.""" - modality_lengths = [] - for example in self.examples: - is_multimodal = "image" in example - n_words = sum([len(turn["value"].split()) for turn in example["conversations"]]) - modality_lengths.append((is_multimodal, n_words)) - return modality_lengths - - def __len__(self) -> int: - return len(self.examples) diff --git a/capvector-oft/prismatic/preprocessing/download.py b/capvector-oft/prismatic/preprocessing/download.py deleted file mode 100644 index 300bc0f42e311a33927789ad48a3e143a255aa5f..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/preprocessing/download.py +++ /dev/null @@ -1,207 +0,0 @@ -""" -download.py - -Utility functions for downloading and extracting various datasets to (local) disk. -""" - -import os -import shutil -from pathlib import Path -from typing import Dict, List, TypedDict -from zipfile import ZipFile - -import requests -from PIL import Image -from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn -from tqdm import tqdm - -from prismatic.overwatch import initialize_overwatch - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# === Dataset Registry w/ Links === -# fmt: off -DatasetComponent = TypedDict( - "DatasetComponent", - {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool}, - total=False -) - -DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = { - # === LLaVa v1.5 Dataset(s) === - - # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5 - # models are finetuned on this split. We use this dataset for all experiments in our paper. 
- "llava-laion-cc-sbu-558k": [ - { - "name": "chat.json", # Contains the "chat" traces :: {"human" => , "gpt" => } - "extract": False, - "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json", - "do_rename": True, - }, - { - "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution) - "extract": True, - "extract_type": "directory", - "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip", - "do_rename": False, - } - ], - - "llava-v1.5-instruct": [ - { - "name": "llava_v1_5_mix665k.json", - "extract": False, - "url": ( - "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json" - ), - "do_rename": True, - }, - { - "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017 - "extract": True, - "extract_type": "directory", - "url": "http://images.cocodataset.org/zips/train2017.zip", - "do_rename": True, - }, - { - "name": "gqa/images", - "extract": True, - "extract_type": "directory", - "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip", - "do_rename": True, - }, - { - "name": "ocr_vqa/images", - "extract": True, - "extract_type": "directory", - "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip", - "do_rename": True, - }, - { - "name": "textvqa/train_images", - "extract": True, - "extract_type": "directory", - "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip", - "do_rename": True, - }, - { - "name": "vg/VG_100K", - "extract": True, - "extract_type": "directory", - "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip", - "do_rename": True, - }, - { - "name": "vg/VG_100K_2", - "extract": True, - "extract_type": "directory", - "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip", - "do_rename": True, - }, - ] -} -# fmt: on - - -def convert_to_jpg(image_dir: Path) -> None: - """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs.""" - overwatch.info(f"Converting all Images in `{image_dir}` to JPG") - - for image_fn in tqdm(list(image_dir.iterdir())): - if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists(): - continue - - if image_fn.suffix == ".gif": - gif = Image.open(image_fn) - gif.seek(0) - gif.convert("RGB").save(jpg_fn) - elif image_fn.suffix == ".png": - Image.open(image_fn).convert("RGB").save(jpg_fn) - else: - raise ValueError(f"Unexpected image format `{image_fn.suffix}`") - - -def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path: - """Utility function for downloading files from the internet, with a handy Rich-based progress bar.""" - overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1) - if dest_path.exists(): - return dest_path - - # Otherwise --> fire an HTTP Request, with `stream = True` - response = requests.get(url, stream=True) - - # Download w/ Transfer-Aware Progress - # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py - with Progress( - TextColumn("[bold]{task.description} - {task.fields[fname]}"), - BarColumn(bar_width=None), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - DownloadColumn(), - "•", - TransferSpeedColumn(), - transient=True, - ) as dl_progress: - dl_tid = dl_progress.add_task( - "Downloading", fname=dest_path.name, 
total=int(response.headers.get("content-length", "None")) - ) - with open(dest_path, "wb") as f: - for data in response.iter_content(chunk_size=chunk_size_bytes): - dl_progress.advance(dl_tid, f.write(data)) - - return dest_path - - -def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path: - """Utility function for extracting compressed archives, with a handy Rich-based progress bar.""" - assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!" - overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1) - - # Extract w/ Progress - with Progress( - TextColumn("[bold]{task.description} - {task.fields[aname]}"), - BarColumn(bar_width=None), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - MofNCompleteColumn(), - transient=True, - ) as ext_progress: - with ZipFile(archive_path) as zf: - ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist())) - extract_path = Path(zf.extract(members[0], download_dir)) - if extract_type == "file": - assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type} has > 1 member!" - elif extract_type == "directory": - for member in members[1:]: - zf.extract(member, download_dir) - ext_progress.advance(ext_tid) - else: - raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!") - - # Cleanup (if specified) - if cleanup: - archive_path.unlink() - - return extract_path - - -def download_extract(dataset_id: str, root_dir: Path) -> None: - """Download all files for a given dataset (querying registry above), extracting archives if necessary.""" - os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True) - - # Download Files => Single-Threaded, with Progress Bar - dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()] - for dl_task in dl_tasks: - dl_path = download_with_progress(dl_task["url"], download_dir) - - # Extract Files (if specified) --> Note (assumes ".zip" ONLY!) - if dl_task["extract"]: - dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"]) - dl_path = dl_path.parent if dl_path.is_file() else dl_path - - # Rename Path --> dl_task["name"] - if dl_task["do_rename"]: - shutil.move(dl_path, download_dir / dl_task["name"]) diff --git a/capvector-oft/prismatic/preprocessing/materialize.py b/capvector-oft/prismatic/preprocessing/materialize.py deleted file mode 100644 index b6605825448e95a8dc30825db2a1e31e9b46efc3..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/preprocessing/materialize.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -materialize.py - -Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for -clear control flow. 
-""" - -from typing import Tuple, Type - -from torch.utils.data import Dataset -from transformers import PreTrainedTokenizerBase - -from prismatic.conf import DatasetConfig -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.models.backbones.vision import ImageTransform -from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset -from prismatic.util.data_utils import PaddedCollatorForLanguageModeling - -# Dataset Initializers =>> Maps Stage --> cls() -DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} - - -def get_dataset_and_collator( - stage: str, - dataset_cfg: DatasetConfig, - image_transform: ImageTransform, - tokenizer: PreTrainedTokenizerBase, - prompt_builder_fn: Type[PromptBuilder], - default_image_resolution: Tuple[int, int, int], - padding_side: str = "right", -) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: - dataset_cls = DATASET_INITIALIZER[stage] - dataset_root_dir = dataset_cfg.dataset_root_dir - collator = PaddedCollatorForLanguageModeling( - tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side - ) - - # Switch on `stage` - if stage == "align": - annotation_json, image_dir = dataset_cfg.align_stage_components - dataset = dataset_cls( - dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer - ) - return dataset, collator - - elif stage == "finetune": - annotation_json, image_dir = dataset_cfg.finetune_stage_components - dataset = dataset_cls( - dataset_root_dir / annotation_json, - dataset_root_dir / image_dir, - image_transform, - tokenizer, - prompt_builder_fn=prompt_builder_fn, - ) - return dataset, collator - - elif stage == "full-finetune": - annotation_json, image_dir = dataset_cfg.finetune_stage_components - dataset = dataset_cls( - dataset_root_dir / annotation_json, - dataset_root_dir / image_dir, - image_transform, - tokenizer, - prompt_builder_fn=prompt_builder_fn, - ) - return dataset, collator - - else: - raise ValueError(f"Stage `{stage}` is not supported!") diff --git a/capvector-oft/prismatic/py.typed b/capvector-oft/prismatic/py.typed deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-oft/prismatic/training/__init__.py b/capvector-oft/prismatic/training/__init__.py deleted file mode 100644 index c66f906fadd8a0fb31acdb0f7f86f6c393dc68ba..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .materialize import get_train_strategy -from .metrics import Metrics, VLAMetrics diff --git a/capvector-oft/prismatic/training/materialize.py b/capvector-oft/prismatic/training/materialize.py deleted file mode 100644 index 5fefd9fdec35b5df0d66d0669a8b625aea64e7ac..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/materialize.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -materialize.py - -Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, -and strategy configurations. -""" - -from typing import Callable, Optional - -import torch - -from prismatic.models.vlms import PrismaticVLM -from prismatic.training.strategies import FSDPStrategy, TrainingStrategy - -# Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
-TRAIN_STRATEGIES = { - "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, - "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, -} - - -def get_train_strategy( - train_strategy: str, - vlm: PrismaticVLM, - device_id: int, - stage: str, - epochs: int, - max_steps: Optional[int], - global_batch_size: int, - per_device_batch_size: int, - learning_rate: float, - weight_decay: float, - max_grad_norm: float, - lr_scheduler_type: str, - warmup_ratio: float, - enable_gradient_checkpointing: bool = True, - enable_mixed_precision_training: bool = True, - reduce_in_full_precision: bool = False, - mixed_precision_dtype: torch.dtype = torch.bfloat16, - worker_init_fn: Optional[Callable[[int], None]] = None, -) -> TrainingStrategy: - if train_strategy in TRAIN_STRATEGIES: - strategy_cfg = TRAIN_STRATEGIES[train_strategy] - strategy = strategy_cfg["cls"]( - vlm=vlm, - device_id=device_id, - stage=stage, - epochs=epochs, - max_steps=max_steps, - global_batch_size=global_batch_size, - per_device_batch_size=per_device_batch_size, - learning_rate=learning_rate, - weight_decay=weight_decay, - max_grad_norm=max_grad_norm, - lr_scheduler_type=lr_scheduler_type, - warmup_ratio=warmup_ratio, - enable_gradient_checkpointing=enable_gradient_checkpointing, - enable_mixed_precision_training=enable_mixed_precision_training, - reduce_in_full_precision=reduce_in_full_precision, - mixed_precision_dtype=mixed_precision_dtype, - worker_init_fn=worker_init_fn, - **strategy_cfg["kwargs"], - ) - return strategy - else: - raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") diff --git a/capvector-oft/prismatic/training/metrics.py b/capvector-oft/prismatic/training/metrics.py deleted file mode 100644 index 6fcc78172e59f9efa6ce2ef4b87d1f19a027821a..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/metrics.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -metrics.py - -Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various -endpoints (e.g., JSONL local logs, Weights & Biases). -""" - -import time -from collections import defaultdict, deque -from pathlib import Path -from typing import Any, Dict, Optional, Protocol, Tuple, Union - -import jsonlines -import numpy as np -import torch -import wandb - -from prismatic.overwatch import initialize_overwatch - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# === Define Tracker Interface === -class Tracker(Protocol): - def write_hyperparameters(self) -> None: ... - - def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: ... - - def finalize(self) -> None: ... 
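

# Example: a minimal custom tracker satisfying the `Tracker` protocol above. The class name and the
# plain-text log format are hypothetical (only the JSONLines / W&B trackers below ship with the codebase).
class PlainTextTracker:
    def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None:
        self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams

    def write_hyperparameters(self) -> None:
        with open(self.run_dir / f"{self.run_id}.log", "a") as f:
            f.write(f"hparams :: {self.hparams}\n")

    def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
        with open(self.run_dir / f"{self.run_id}.log", "a") as f:
            f.write(f"step {global_step:06d} :: {metrics}\n")

    def finalize(self) -> None:
        return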
- - -# === Individual Tracker Definitions === -class JSONLinesTracker: - def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None: - self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams - - @overwatch.rank_zero_only - def write_hyperparameters(self) -> None: - with jsonlines.open(self.run_dir / "run-metrics.jsonl", mode="w", sort_keys=True) as js_tracker: - js_tracker.write({"run_id": self.run_id, "hparams": self.hparams}) - - @overwatch.rank_zero_only - def write(self, _: int, metrics: Dict[str, Union[int, float]]) -> None: - with jsonlines.open(self.run_dir / f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_tracker: - js_tracker.write(metrics) - - def finalize(self) -> None: - return - - -class WeightsBiasesTracker: - def __init__( - self, - run_id: str, - run_dir: Path, - hparams: Dict[str, Any], - project: str = "prismatic", - entity: Optional[str] = None, - group: str = "align", - ) -> None: - self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams - - # Get W&B-Specific Initialization Parameters - self.project, self.entity, self.group, self.wandb_dir = project, entity, group, self.run_dir - - # Call W&B.init() - self.initialize() - - @overwatch.rank_zero_only - def initialize(self) -> None: - wandb.init( - name=self.run_id, - dir=self.wandb_dir, - config=self.hparams, - project=self.project, - entity=self.entity, - group=self.group, - ) - - @overwatch.rank_zero_only - def write_hyperparameters(self) -> None: - wandb.config = self.hparams - - @overwatch.rank_zero_only - def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: - wandb.log(metrics, step=global_step) - - @staticmethod - def finalize() -> None: - if overwatch.is_rank_zero(): - wandb.finish() - - # A job gets 210 seconds to get its affairs in order - time.sleep(210) - - -# === Core Metrics Container :: Initializes Trackers => Compiles/Pushes Metrics === - - -class Metrics: - def __init__( - self, - active_trackers: Tuple[str, ...], - run_id: str, - run_dir: Path, - hparams: Dict[str, Any], - stage: str, - wandb_project: str = "prismatic", - wandb_entity: Optional[str] = None, - grad_accumulation_steps: int = 1, - window_size: int = 128, - ) -> None: - self.run_id, self.run_dir, self.hparams, self.stage = run_id, run_dir, hparams, stage - - # Initialize Trackers - self.trackers = [] - for tracker_type in active_trackers: - if tracker_type == "jsonl": - tracker = JSONLinesTracker(run_id, run_dir, hparams) - elif tracker_type == "wandb": - tracker = WeightsBiasesTracker( - run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group=self.stage - ) - else: - raise ValueError(f"Tracker with type `{tracker_type} is not supported!") - - # Add Hyperparameters --> add to `self.trackers` - tracker.write_hyperparameters() - self.trackers.append(tracker) - - # Create Universal Metrics Buffers - self.global_step, self.start_time, self.step_start_time = 0, time.time(), time.time() - self.state = { - "loss_raw": deque(maxlen=grad_accumulation_steps), - "loss": deque(maxlen=window_size), - "step_time": deque(maxlen=window_size), - "lr": [], - } - - def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: - for tracker in self.trackers: - tracker.write(global_step, metrics) - - def get_status(self, loss: Optional[torch.Tensor] = None) -> str: - lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 - if loss is None: - return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f}" - - # Otherwise, embed `loss` 
in status report! - return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}" - - def commit( - self, *, global_step: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs - ) -> None: - """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" - if global_step is not None: - self.global_step = global_step - - # For all other variables --> only track on rank zero! - if not overwatch.is_rank_zero(): - return - - # Special Positional Arguments - if lr is not None: - self.state["lr"].append(lr) - - if update_step_time: - self.state["step_time"].append(time.time() - self.step_start_time) - self.step_start_time = time.time() - - # Generic Keyword Arguments - for key, value in kwargs.items(): - if key == "loss": - loss_val = value.detach() - self.state["loss_raw"].append(loss_val) - self.state["loss"].append(loss_val) - else: - self.state[key].append(value.detach()) - - @overwatch.rank_zero_only - def push(self) -> str: - # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! - loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() - loss = torch.stack(list(self.state["loss"])).mean().item() - step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] - status = self.get_status(loss) - - # Fire to Trackers - prefix = self.stage.capitalize() - self.log( - self.global_step, - metrics={ - f"{prefix}/Step": self.global_step, - f"{prefix}/Loss": loss, - f"{prefix}/Loss (Raw)": loss_raw, - f"{prefix}/Learning Rate": lr, - f"{prefix}/Step Time": step_time, - }, - ) - return status - - def finalize(self) -> str: - for tracker in self.trackers: - tracker.finalize() - - -class VLAMetrics: - def __init__( - self, - active_trackers: Tuple[str, ...], - run_id: str, - run_dir: Path, - hparams: Dict[str, Any], - wandb_project: str = "openvla", - wandb_entity: Optional[str] = "stanford-voltron", - grad_accumulation_steps: int = 1, - window_size: int = 1, - resume_step: Optional[int] = None, - resume_epoch: Optional[int] = None, - ) -> None: - self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams - - # Initialize Trackers - self.trackers = [] - for tracker_type in active_trackers: - if tracker_type == "jsonl": - tracker = JSONLinesTracker(run_id, run_dir, hparams) - elif tracker_type == "wandb": - tracker = WeightsBiasesTracker( - run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group="vla-train" - ) - else: - raise ValueError(f"Tracker with type `{tracker_type} is not supported!") - - # Add Hyperparameters --> add to `self.trackers` - tracker.write_hyperparameters() - self.trackers.append(tracker) - - # Create Universal Metrics Buffers - self.global_step = 0 if resume_step is None else resume_step - self.epoch = 0 if resume_epoch is None else resume_epoch - self.start_time, self.step_start_time = time.time(), time.time() - self.state = { - "loss_raw": deque(maxlen=grad_accumulation_steps), - "loss": deque(maxlen=window_size), - "l1_loss": deque(maxlen=window_size), - "action_accuracy": deque(maxlen=window_size), - "step_time": deque(maxlen=window_size), - "lr": [], - } - - # Created metrics buffers for individual tracked datasets - self.dataset_trackers = defaultdict(lambda: VLAMetrics([], "", "", {})) - - def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: - for tracker in self.trackers: - tracker.write(global_step, metrics) - - def get_status(self, loss: Optional[torch.Tensor] = 
None) -> str: - lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 - if loss is None: - return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}" - - # Otherwise, embed `loss` in status report! - return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} - Loss :: {loss:.4f}" - - def commit( - self, - *, - global_step: Optional[int] = None, - epoch: Optional[int] = None, - lr: Optional[float] = None, - update_step_time: bool = False, - **kwargs, - ) -> None: - """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" - if global_step is not None: - self.global_step = global_step - - if epoch is not None: - self.epoch = epoch - - # For all other variables --> only track on rank zero! - if not overwatch.is_rank_zero(): - return - - # Special Positional Arguments - if lr is not None: - self.state["lr"].append(lr) - - if update_step_time: - self.state["step_time"].append(time.time() - self.step_start_time) - self.step_start_time = time.time() - - # Generic Keyword Arguments - for key, value in kwargs.items(): - if key == "loss": - loss_val = value.detach() - self.state["loss_raw"].append(loss_val) - self.state["loss"].append(loss_val) - else: - self.state[key].append(value.detach()) - - def commit_for_dataset(self, dataset_name: str, **kwargs) -> None: - self.dataset_trackers[dataset_name].commit(**kwargs) - - @overwatch.rank_zero_only - def push(self) -> str: - # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! - loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() - loss = torch.stack(list(self.state["loss"])).mean().item() - l1_loss = torch.stack(list(self.state["l1_loss"])).mean().item() - action_accuracy = torch.stack(list(self.state["action_accuracy"])).mean().item() - step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] - status = self.get_status(loss) - - # Get metrics per dataset - dataset_metrics = {} - for ds, tracker in self.dataset_trackers.items(): - dataset_metrics.update( - { - f"{ds}/L1 Loss": torch.stack(list(tracker.state["l1_loss"])).mean().item(), - f"{ds}/Action Token Accuracy": torch.stack(list(tracker.state["action_accuracy"])).mean().item(), - } - ) - - # Fire to Trackers - prefix = "VLA Train" - self.log( - self.global_step, - metrics={ - f"{prefix}/Step": self.global_step, - f"{prefix}/Epoch": self.epoch, - f"{prefix}/Loss": loss, - f"{prefix}/L1 Loss": l1_loss, - f"{prefix}/Action Token Accuracy": action_accuracy, - f"{prefix}/Loss (Raw)": loss_raw, - f"{prefix}/Learning Rate": lr, - f"{prefix}/Step Time": step_time, - **dataset_metrics, - }, - ) - return status - - def finalize(self) -> str: - for tracker in self.trackers: - tracker.finalize() diff --git a/capvector-oft/prismatic/training/strategies/__init__.py b/capvector-oft/prismatic/training/strategies/__init__.py deleted file mode 100644 index 748155bebe6f45bfb012c341e526d5fa47e589eb..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/strategies/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base_strategy import TrainingStrategy -from .ddp import DDPStrategy -from .fsdp import FSDPStrategy diff --git a/capvector-oft/prismatic/training/strategies/base_strategy.py b/capvector-oft/prismatic/training/strategies/base_strategy.py deleted file mode 100644 index adb8c0fde4af5ce130c2b22b6a366bf32eb5bbd1..0000000000000000000000000000000000000000 --- 
a/capvector-oft/prismatic/training/strategies/base_strategy.py +++ /dev/null @@ -1,417 +0,0 @@ -""" -base_strategy.py - -Abstract class definition of a (distributed) training strategy, with full annotations of class methods, utility -functions, and initialization logic. - -Training Strategies (DDP, FSDP-Grad, FSDP-Full) tend to have a lot of repeated components; this class does a lot of -heavy lifting. -""" - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Callable, Optional - -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset -from tqdm import tqdm -from transformers.modeling_outputs import CausalLMOutputWithPast - -from prismatic.models.vlms import PrismaticVLM -from prismatic.overwatch import initialize_overwatch -from prismatic.training.metrics import Metrics, VLAMetrics -from prismatic.training.train_utils import ( - compute_actions_l1_loss, - compute_token_accuracy, - get_current_action_mask, - get_next_actions_mask, -) -from prismatic.util import check_bloat16_supported -from prismatic.util.batching_utils import SplitModalitySampler -from prismatic.util.data_utils import PaddedCollatorForActionPrediction, PaddedCollatorForLanguageModeling -from prismatic.vla.action_tokenizer import ActionTokenizer - -# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) -from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, NUM_ACTIONS_CHUNK, IGNORE_INDEX -NEWLINE_INDEX = 13 # '\n' -STOP_INDEX = 2 # '' - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# === Abstract Base Class for an arbitrary Training Strategy === -class TrainingStrategy(ABC): - def __init__( - self, - vlm: PrismaticVLM, - device_id: int, - stage: str, - epochs: int, - max_steps: Optional[int], - global_batch_size: int, - per_device_batch_size: int, - learning_rate: float, - weight_decay: float, - max_grad_norm: float, - lr_scheduler_type: str, - warmup_ratio: float, - enable_gradient_checkpointing: bool = True, - enable_mixed_precision_training: bool = True, - reduce_in_full_precision: bool = False, - mixed_precision_dtype: torch.dtype = torch.bfloat16, - worker_init_fn: Optional[Callable[[int], None]] = None, - **_: str, - ) -> None: - self.vlm, self.device_id, self.stage = vlm, device_id, stage - - # Get relevant VLM instance parameters before they get (potentially) wrapped - self.all_module_keys, self.trainable_module_keys = self.vlm.all_module_keys, self.vlm.trainable_module_keys - self.llm_transformer_layer_cls = self.vlm.llm_backbone.transformer_layer_cls - - # Optimization Parameters - self.epochs, self.max_steps = epochs, max_steps - self.global_batch_size, self.per_device_batch_size = global_batch_size, per_device_batch_size - - self.learning_rate, self.weight_decay, self.max_grad_norm = learning_rate, weight_decay, max_grad_norm - self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio - - # Generic Strategy Parameters - self.enable_gradient_checkpointing = enable_gradient_checkpointing - self.enable_mixed_precision_training = enable_mixed_precision_training - self.reduce_in_full_precision = reduce_in_full_precision - self.mixed_precision_dtype = mixed_precision_dtype - - # DataLoader Parameters - self.worker_init_fn = worker_init_fn - - # Optimizers & Scheduler (initialized in `run_setup`) - self.optimizer, self.lr_scheduler = None, None - - # Lightweight Validation - assert ( - 
self.global_batch_size % self.per_device_batch_size == 0 - ), "Per-device batch size must evenly divide global batch size!" - self.grad_accumulation_steps = self.global_batch_size // self.per_device_batch_size // overwatch.world_size() - if self.enable_mixed_precision_training: - assert self.mixed_precision_dtype == torch.bfloat16, "Only BF16 mixed precision training is supported!" - assert check_bloat16_supported(), "BFloat16 is not supported on this hardware; unset `mixed_precision`" - - @abstractmethod - def save_checkpoint( - self, - run_dir: Path, - global_step: int, - epoch: int, - train_loss: Optional[float] = None, - only_trainable: bool = True, - ) -> None: ... - - @abstractmethod - def run_setup(self, run_dir: Path, n_train_examples: int) -> None: ... - - @abstractmethod - def clip_grad_norm(self) -> None: ... - - def run_training( - self, - dataset: Dataset, - collator: PaddedCollatorForLanguageModeling, - metrics: Metrics, - stage: str = "finetune", - batch_construction_strategy: str = "split-modality", - seed: int = 7, - ) -> None: - """Run the training loop for the given `dataset` and `collator`; log losses, results to `metrics`""" - if "finetune" in stage and batch_construction_strategy == "split-modality": - # Instantiate the split-modality sampler; if you want to extend with other batch construction schemes, - # (e.g., grouping by length) =>> can easily add them here! - modality_lengths = dataset.get_modality_lengths() - sampler = SplitModalitySampler( - dataset, - modality_lengths, - global_batch_size=self.global_batch_size, - num_replicas=overwatch.world_size(), - rank=overwatch.rank(), - seed=seed, - drop_last=False, - ) - - else: - sampler = DistributedSampler( - dataset, - num_replicas=overwatch.world_size(), - rank=overwatch.rank(), - shuffle=True, - seed=seed, - drop_last=False, - ) - - # Create a DataLoader with the initialized sampler, per-device-bsz, and collator - dataloader = DataLoader( - dataset, - batch_size=self.per_device_batch_size, - sampler=sampler, - collate_fn=collator, - num_workers=2, - worker_init_fn=self.worker_init_fn, - ) - - # Max Steps vs. Epochs Computation - steps_per_epoch = len(dataloader) // self.grad_accumulation_steps - if self.max_steps is not None and steps_per_epoch < self.max_steps: - # Just set `epochs` to some large number --> we'll short-circuit based on steps anyway - self.epochs = 100 - - # === Train === - status = metrics.get_status() - with tqdm( - total=( - (self.epochs * (len(dataloader) // self.grad_accumulation_steps)) - if self.max_steps is None - else self.max_steps - ), - desc=status, - leave=False, - disable=not overwatch.is_rank_zero(), - ) as progress: - for epoch in range(self.epochs): - self.vlm.train() - sampler.set_epoch(epoch) - - # Zero-Gradients (just in case) - self.optimizer.zero_grad() - - # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call - # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! - for train_idx, batch in enumerate(dataloader): - # [Contract] self.vlm.forward() must automatically compute `loss` and return! 
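# --- Illustrative aside (a minimal sketch, not part of the original file): why `loss` is divided by
# --- `grad_accumulation_steps` before `.backward()` below. With toy numbers, a global batch of 128 on
# --- 8 GPUs at a per-device batch of 4 gives 128 // 4 // 8 = 4 micro-batches per optimizer step, and
# --- scaling each micro-batch loss by 1/4 makes the accumulated gradient match a single backward pass
# --- over the averaged loss.
import torch

w = torch.zeros(3, requires_grad=True)
micro_losses = [((w - torch.full((3,), float(t))) ** 2).mean() for t in range(4)]

# Accumulate scaled micro-batch gradients (mirrors the accumulation loop below).
for micro_loss in micro_losses:
    (micro_loss / 4).backward(retain_graph=True)
accumulated_grad = w.grad.clone()

# Compare against one backward pass on the mean of the same micro-batch losses.
w.grad = None
torch.stack(micro_losses).mean().backward()
assert torch.allclose(accumulated_grad, w.grad)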
- with torch.autocast( - "cuda", - dtype=self.mixed_precision_dtype, - enabled=self.enable_mixed_precision_training, - ): - output: CausalLMOutputWithPast = self.vlm( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - pixel_values=batch["pixel_values"], - labels=batch["labels"], - multimodal_indices=batch["multimodal_indices"], - ) - loss = output.loss - - # Commit Loss (Prior to Gradient Accumulation Normalization) - metrics.commit(loss=loss) - - # Normalize Loss to account for Gradient Accumulation --> Backward! - # [IMPORTANT] Technically speaking, doing gradient accumulation in this way is "incorrect"; this is - # because in general, each batch has a *different number of masked out tokens* (because - # we're instruct-tuning). Taking the mean over two unbalanced means != the right thing! - # - # HOWEVER -- at least at the 7B scale, the "naive" approach is just as performant as - # the "correct" implementation, without adding extra complexity. - # - # That being said =>> at the 13B scale, *no matter what we tried, ANY gradient accumulation is just - # really bad for downstream performance. Initial investigation shows that BF16 accumulation - # just really tanks in precision... and don't have a good/clean way to fix this. Would love for - # someone to PR and fix this (and I'd greatly appreciate it!!!) - normalized_loss = loss / self.grad_accumulation_steps - normalized_loss.backward() - - # Step =>> Only if Done w/ Gradient Accumulation - if (train_idx + 1) % self.grad_accumulation_steps == 0: - metrics.commit(update_step_time=True) - - # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality-assumptions - self.clip_grad_norm() - - # Optimizer & LR Scheduler Step - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - - # Push Metrics - metrics.commit(global_step=metrics.global_step + 1, lr=self.lr_scheduler.get_last_lr()[0]) - status = metrics.push() - - # Check for Termination & Save Final Checkpoint (in case `max_steps` is not None) - if self.max_steps is not None and metrics.global_step >= self.max_steps: - self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item()) - dist.barrier() - - return - - # Update Progress Bar - progress.update() - progress.set_description(status) - - # Save checkpoint at end each epoch (if `self.max_steps` is None) - if self.max_steps is None: - self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item()) - dist.barrier() - - # === VLA Training === - - def run_vla_training( - self, - vla_dataset: IterableDataset, - collator: PaddedCollatorForActionPrediction, - action_tokenizer: ActionTokenizer, - metrics: VLAMetrics, - save_interval: int = 2500, - save_full_model: bool = True, - ) -> None: - """Run the VLA training loop for the given `dataset` and `collator`; log losses, action metrics to `metrics`.""" - assert isinstance(vla_dataset, IterableDataset), "VLA training expects an IterableDataset!" - assert self.grad_accumulation_steps == 1, "VLA training does not support gradient accumulation!" - - # Create a DataLoader =>> Set `num_workers` to 0; RLDS loader handles parallelism! 
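# --- Minimal sketch (hypothetical names, not from this repo) of the kind of IterableDataset the VLA
# --- loop expects: it just yields dicts from an underlying RLDS-style numpy iterator, which already
# --- handles shuffling/parallelism and implicitly repeats, hence `num_workers=0` below.
import itertools
from torch.utils.data import IterableDataset


class RLDSWrapper(IterableDataset):
    def __init__(self, rlds_numpy_iterable):
        # e.g., the result of a TF dataset's `.as_numpy_iterator()`
        self.rlds_numpy_iterable = rlds_numpy_iterable

    def __iter__(self):
        # Effectively infinite (the underlying pipeline uses `.repeat()`), so it never exhausts.
        yield from self.rlds_numpy_iterable


# Toy usage with a stand-in infinite stream:
toy_stream = ({"step": i} for i in itertools.count())
first_example = next(iter(RLDSWrapper(toy_stream)))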
- dataloader = DataLoader( - vla_dataset, - batch_size=self.per_device_batch_size, - sampler=None, - collate_fn=collator, - num_workers=0, - worker_init_fn=self.worker_init_fn, - ) - - # === Train === - status = metrics.get_status() - with tqdm( - total=(self.epochs * len(dataloader)) if self.max_steps is None else self.max_steps, - desc=status, - leave=False, - disable=not overwatch.is_rank_zero(), - ) as progress: - self.vlm.train() - - # Zero Gradients (just in case) - self.optimizer.zero_grad() - - # [Contract] DataLoader wraps RLDS Loader (`.as_numpy_iterator() =>> implicit `.repeat()`) - # => This means looping over the DataLoader is basically "infinite" (so no outer loop over epochs). - # Slightly breaks default PyTorch semantics, which is why we adaptively compute `epoch` below. - for batch in dataloader: - # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call - # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! - with torch.autocast( - "cuda", dtype=self.mixed_precision_dtype, enabled=self.enable_mixed_precision_training - ): - # [Contract] self.vlm.forward() must automatically compute `loss` and return! - output: CausalLMOutputWithPast = self.vlm( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - pixel_values=batch["pixel_values"], - labels=batch["labels"], - ) - loss = output.loss - - # Commit Loss =>> Backward! - metrics.commit(loss=loss) - loss.backward() - - # Get predicted and ground-truth token IDs - predicted_token_ids = output.logits[:, self.vlm.vision_backbone.num_patches : -1].argmax(dim=2) - ground_truth_token_ids = batch["labels"][:, 1:].to(predicted_token_ids.device) - - ####################################################################### - # === Compute Current Action Token Accuracy & L1 Loss === - ####################################################################### - - # Get current action mask: Target the first ACTION_DIM non-ignore tokens - current_action_mask = get_current_action_mask(ground_truth_token_ids) - - # Compute Accuracy - action_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=current_action_mask) - - # Compute L1 Loss on Predicted (Continuous) Actions - action_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask) - - ####################################################################### - # === Compute Next Actions Token Accuracy & L1 Loss === - ####################################################################### - - # Get next actions mask: Target all tokens after the first ACTION_DIM non-ignore tokens (excluding the last token, which is the stop token) - next_actions_mask = get_next_actions_mask(ground_truth_token_ids) - - # Compute Accuracy - next_actions_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask) - - # Compute L1 Loss on Predicted (Continuous) Actions - next_actions_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask) - - ####################################################################### - # === Log === - ####################################################################### - - # Commit Metrics - metrics.commit( - action_accuracy=action_accuracy, - l1_loss=action_l1_loss, - next_actions_accuracy=next_actions_accuracy, - next_actions_l1_loss=next_actions_l1_loss, - update_step_time=True, - ) - - # 
Compute metrics per dataset --> only on rank_zero since we don't log them on other workers anyways - if overwatch.is_rank_zero(): - datasets = set(batch["dataset_names"]) - if len(datasets) > 1: - for ds in datasets: - ds_mask = torch.tensor([elem == ds for elem in batch["dataset_names"]]) - action_accuracy_ds = correct_preds[ds_mask].sum().float() / mask[ds_mask].sum().float() - pred_continuous_actions_ds = torch.tensor( - action_tokenizer.decode_token_ids_to_actions( - predicted_token_ids[ds_mask][mask[ds_mask]].cpu().numpy() - ) - ) - continuous_actions_gt_ds = torch.tensor( - action_tokenizer.decode_token_ids_to_actions( - ground_truth_token_ids[ds_mask][mask[ds_mask]].cpu().numpy() - ) - ) - action_l1_loss_ds = torch.nn.functional.l1_loss( - pred_continuous_actions_ds, continuous_actions_gt_ds - ) - metrics.commit_for_dataset( - dataset_name=ds.decode(), - action_accuracy=action_accuracy_ds, - l1_loss=action_l1_loss_ds, - next_actions_accuracy=next_actions_accuracy, - next_actions_l1_loss=next_actions_l1_loss, - ) - - # === Gradient Step === - - # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality assumptions - self.clip_grad_norm() - - # Optimizer & LR Scheduler Step - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - - # Compute epoch value using number of completed gradient steps - epoch = (metrics.global_step + 1) // (len(vla_dataset) // self.global_batch_size) - - # Push Metrics - metrics.commit(global_step=metrics.global_step + 1, epoch=epoch, lr=self.lr_scheduler.get_last_lr()[0]) - status = metrics.push() - - # Check for Save Interval or Max Steps & Save Checkpoint - if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or ( - (metrics.global_step % save_interval) == 0 - ): - self.save_checkpoint( - metrics.run_dir, metrics.global_step, epoch, loss.item(), only_trainable=not save_full_model - ) - dist.barrier() - - if terminate: - return - - # Update Progress Bar - progress.update() - progress.set_description(status) diff --git a/capvector-oft/prismatic/training/strategies/ddp.py b/capvector-oft/prismatic/training/strategies/ddp.py deleted file mode 100644 index 84e685d3fdfc123d2868e5de8201d927d24cceb2..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/strategies/ddp.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -ddp.py - -Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most -GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP. 
-""" - -import shutil -from pathlib import Path -from typing import Optional - -import torch -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import AdamW -from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup - -from prismatic.overwatch import initialize_overwatch -from prismatic.training.strategies.base_strategy import TrainingStrategy - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -class DDPStrategy(TrainingStrategy): - @overwatch.rank_zero_only - def save_checkpoint( - self, - run_dir: Path, - global_step: int, - epoch: int, - train_loss: Optional[float] = None, - only_trainable: bool = True, - ) -> None: - """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" - assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" - - # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) - model_state_dicts = { - mkey: getattr(self.vlm.module, mkey).state_dict() - for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) - } - optimizer_state_dict = self.optimizer.state_dict() - - # Set Checkpoint Path =>> Embed *minimal* training statistics! - checkpoint_dir = run_dir / "checkpoints" - if train_loss is None: - checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" - else: - checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" - - # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` - torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) - shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") - - def run_setup(self, run_dir: Path, n_train_examples: int) -> None: - # Gradient Checkpointing Setup - if self.enable_gradient_checkpointing: - # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up - # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF - # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` - # on `self.llm_backbone`. - # - # What does it actually do? --> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic - # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 - # - # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) - # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb - overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) - self.vlm.llm_backbone.gradient_checkpointing_enable() - - # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) - overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) - self.vlm.to(self.device_id) - - # Wrap with Distributed Data Parallel - # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that - # is the same size/dtype as the model parameters; this will *double* GPU memory! 
- # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel - overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) - self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) - - # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` - # => Optimizer should only operate on parameters that are *unfrozen* / trainable! - trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] - if self.max_steps is None: - num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size - else: - num_training_steps = self.max_steps - - if self.lr_scheduler_type == "linear-warmup+cosine-decay": - # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) - num_warmup_steps = int(num_training_steps * self.warmup_ratio) - - assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" - self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) - self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) - for param_group in self.optimizer.param_groups: - param_group["lr"] = 0.0 - - elif self.lr_scheduler_type == "constant": - num_warmup_steps = 0 - - assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" - self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) - self.lr_scheduler = get_constant_schedule(self.optimizer) - - else: - raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") - - # Finalize Setup =>> Log - overwatch.info( - "DDP Strategy =>> Finalized Training Setup:\n" - f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" - f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" - f" |-> Distributed World Size = {overwatch.world_size()}\n" - f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" - f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" - f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" - f" |-> Default AdamW LR = {self.learning_rate}\n" - f" |-> AdamW Weight Decay = {self.weight_decay}\n" - f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" - f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" - f" |-> Dataset Size = {n_train_examples} Examples\n" - f" |-> Max Steps = {num_training_steps}\n" - ) - - def clip_grad_norm(self) -> None: - torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) diff --git a/capvector-oft/prismatic/training/strategies/fsdp.py b/capvector-oft/prismatic/training/strategies/fsdp.py deleted file mode 100644 index 9d59e41dab6ac5469e55375469fdc58575f24e24..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/strategies/fsdp.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -fsdp.py - -Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for -fine-grained control over wrapping policies and mixed precision per component). 
-""" - -import math -from collections import OrderedDict -from functools import partial -from pathlib import Path -from typing import Callable, Optional - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, - apply_activation_checkpointing, - checkpoint_wrapper, -) -from torch.distributed.fsdp import ( - FullStateDictConfig, - MixedPrecision, - ShardingStrategy, - StateDictType, -) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.optim import AdamW -from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup - -from prismatic.models.vlms import PrismaticVLM -from prismatic.overwatch import initialize_overwatch -from prismatic.training.strategies.base_strategy import TrainingStrategy - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -class FSDPStrategy(TrainingStrategy): - def __init__( - self, - vlm: PrismaticVLM, - device_id: int, - stage: str, - epochs: int, - max_steps: Optional[int], - global_batch_size: int, - per_device_batch_size: int, - learning_rate: float, - weight_decay: float, - max_grad_norm: float, - lr_scheduler_type: str, - warmup_ratio: float, - enable_gradient_checkpointing: bool = True, - enable_mixed_precision_training: bool = True, - reduce_in_full_precision: bool = False, - mixed_precision_dtype: torch.dtype = torch.bfloat16, - worker_init_fn: Optional[Callable[[int], None]] = None, - sharding_strategy: str = "shard-grad-op", - state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT, - ) -> None: - super().__init__( - vlm=vlm, - device_id=device_id, - stage=stage, - epochs=epochs, - max_steps=max_steps, - global_batch_size=global_batch_size, - per_device_batch_size=per_device_batch_size, - learning_rate=learning_rate, - weight_decay=weight_decay, - max_grad_norm=max_grad_norm, - lr_scheduler_type=lr_scheduler_type, - warmup_ratio=warmup_ratio, - enable_gradient_checkpointing=enable_gradient_checkpointing, - enable_mixed_precision_training=enable_mixed_precision_training, - reduce_in_full_precision=reduce_in_full_precision, - mixed_precision_dtype=mixed_precision_dtype, - worker_init_fn=worker_init_fn, - ) - - # FSDP-Specific Parameters - if sharding_strategy == "shard-grad-op": - self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 - elif sharding_strategy == "full-shard": - self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD - else: - raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!") - - assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!" - self.fsdp_state_dict_type = state_dict_type - self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - - def save_checkpoint( - self, - run_dir: Path, - global_step: int, - epoch: int, - train_loss: Optional[float] = None, - only_trainable: bool = True, - ) -> None: - """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" - assert isinstance(self.vlm, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!" 
- - # Summon Full State Dictionary =>> Reconstitute from Shards - with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, self.fsdp_save_policy): - full_vlm_state_dict = self.vlm.state_dict() - model_state_dicts = { - mkey: OrderedDict() for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) - } - - # Iterate through `full_vlm_state_dict` and split `mkey.{full_dotted_path}` -> `mkey: {full_dotted_path}` - for key, param in full_vlm_state_dict.items(): - for mkey in model_state_dicts: - if key.startswith(mprefix := f"{mkey}."): - model_state_dicts[mkey][key.removeprefix(mprefix)] = param - - # Save on rank zero *only* - if overwatch.is_rank_zero(): - checkpoint_dir = run_dir / "checkpoints" - if train_loss is None: - checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" - else: - checkpoint_path = ( - checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" - ) - - # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` - torch.save({"model": model_state_dicts}, checkpoint_path) - - # TODO (siddk) :: This breaks w/ Sagemaker default permissions (root vs. )... skip? - # shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") - - def run_setup(self, run_dir: Path, n_train_examples: int) -> None: - # Iteratively Assemble FSDP Wrapping Policy by fetching the wrapping policies for each backbone/constituent - vlm_fsdp_wrapping_policy = self.vlm.get_fsdp_wrapping_policy() - - # Assemble the Default FSDP Mixed Precision Policy - if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16: - # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only) - # => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision - reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32 - fsdp_precision_policy = MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype - ) - - # When running FSDP with a frozen vision backbone --> move to half precision! - if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}: - overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`") - self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype) - - else: - # If we're not using mixed precision, everything is in default full precision! - fsdp_precision_policy = MixedPrecision( - param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32 - ) - - # => note that FSDP will automatically take care of device placement (similar to `autocast`) - self.vlm = FSDP( - self.vlm, - auto_wrap_policy=vlm_fsdp_wrapping_policy, - mixed_precision=fsdp_precision_policy, - sharding_strategy=self.fsdp_sharding_strategy, - device_id=torch.cuda.current_device(), - limit_all_gathers=True, - use_orig_params=True, - ) - - # Gradient Checkpoint Setup - if self.enable_gradient_checkpointing: - # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the - # bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we - # cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics! - # - # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer. 
- non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) - - def check_fn(submodule: nn.Module) -> bool: - return isinstance(submodule, self.llm_transformer_layer_cls) - - # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous! - apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) - - # Barrier =>> Sharding takes a minute? - dist.barrier() - - # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` - # => Optimizer should only operate on parameters that are *unfrozen* / trainable! - n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size - if self.max_steps is None: - num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size - else: - num_training_steps = self.max_steps - - if self.lr_scheduler_type == "linear-warmup+cosine-decay": - # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) - num_warmup_steps = int(num_training_steps * self.warmup_ratio) - - # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay - # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! - decay, no_decay = [], [] - for name, param in self.vlm.named_parameters(): - if not param.requires_grad: - continue - - # Check on any parameters with fewer than 2 dimensions or with "bias" in the name - if param.ndim <= 1 or name.endswith(".bias"): - no_decay.append(param) - else: - decay.append(param) - - # Build Parameter Groups - groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] - - # Create Optimizer & LR Scheduler - self.optimizer = AdamW(groups, lr=self.learning_rate) - self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) - for param_group in self.optimizer.param_groups: - param_group["lr"] = 0.0 - - elif self.lr_scheduler_type == "constant": - num_warmup_steps = 0 - - # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay - # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! - decay, no_decay = [], [] - for name, param in self.vlm.named_parameters(): - if not param.requires_grad: - continue - - # Check on any parameters with fewer than 2 dimensions or with "bias" in the name - if param.ndim <= 1 or name.endswith(".bias"): - no_decay.append(param) - else: - decay.append(param) - - # Build Parameter Groups - groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] - - # Create Optimizer & LR Scheduler - self.optimizer = AdamW(groups, lr=self.learning_rate) - self.lr_scheduler = get_constant_schedule(self.optimizer) - - else: - raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") - - # Finalize Setup =>> Log! 
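# --- Small standalone sketch (toy LR and step counts) of the linear-warmup + cosine-decay schedule
# --- configured above; consistent with the `param_group["lr"] = 0.0` reset, warmup begins at zero and
# --- the scheduler drives the effective LR on every step.
import torch
from transformers.optimization import get_cosine_schedule_with_warmup

toy_params = [torch.nn.Parameter(torch.zeros(1))]
toy_optimizer = torch.optim.AdamW(toy_params, lr=2e-5)
toy_scheduler = get_cosine_schedule_with_warmup(toy_optimizer, num_warmup_steps=5, num_training_steps=100)

lrs = []
for _ in range(100):
    toy_optimizer.step()
    toy_scheduler.step()
    lrs.append(toy_scheduler.get_last_lr()[0])

# LR ramps linearly up to 2e-5 over the first 5 steps, then follows a cosine decay toward ~0.
print(lrs[:5], lrs[-1])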
- overwatch.info( - "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n" - f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" - f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" - f" |-> Distributed World Size = {overwatch.world_size()}\n" - f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" - f" |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" - f" |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n" - f" |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n" - f" |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n" - f" |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n" - f" |-> Default AdamW LR = {self.learning_rate}\n" - f" |-> AdamW Weight Decay = {self.weight_decay}\n" - f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" - f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" - f" |-> Dataset Size = {n_train_examples} Examples\n" - f" |-> Max Steps = {num_training_steps}\n" - ) - - def clip_grad_norm(self) -> None: - # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype* - self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm) diff --git a/capvector-oft/prismatic/training/train_utils.py b/capvector-oft/prismatic/training/train_utils.py deleted file mode 100644 index 62fa76a61bad11f41ef9c387aaa691b9e25089b3..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/training/train_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Utils for training/fine-tuning scripts.""" - -import torch - -from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX - - -def get_current_action_mask(token_ids): - # Create a tensor marking positions of IGNORE_INDEX - newline_positions = token_ids != IGNORE_INDEX - - # Calculate cumulative sum to identify regions between newlines - cumsum = torch.cumsum(newline_positions, dim=1) - - # Create the mask - mask = (1 <= cumsum) & (cumsum <= ACTION_DIM) - - # Extract the action part only - action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX - mask = action_tokens_only_mask * mask - - return mask - - -def get_next_actions_mask(token_ids): - # Create a tensor marking positions of IGNORE_INDEX - newline_positions = token_ids != IGNORE_INDEX - - # Calculate cumulative sum to identify regions between newlines - cumsum = torch.cumsum(newline_positions, dim=1) - - # Create the mask - mask = cumsum > ACTION_DIM - - # Extract the action part only - action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX - mask = action_tokens_only_mask * mask - - return mask - - -def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask): - correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask - accuracy = correct_preds.sum().float() / mask.sum().float() - return accuracy - - -def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask): - pred_continuous_actions = torch.tensor( - action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy()) - ) - true_continuous_actions = torch.tensor( - action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy()) - ) - l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions) - return l1_loss diff --git a/capvector-oft/prismatic/util/__init__.py b/capvector-oft/prismatic/util/__init__.py deleted file mode 100644 index 
71f1ff62dffae813bf8bc9281b7e61cbaef81ccf..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/util/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .torch_utils import check_bloat16_supported, set_global_seed diff --git a/capvector-oft/prismatic/util/batching_utils.py b/capvector-oft/prismatic/util/batching_utils.py deleted file mode 100644 index 558ebc4ac8536a78af73daaa6bea984d03d3f689..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/util/batching_utils.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -batching_utils.py - -Core definitions of (Distributed) Samplers for VLM finetuning; provides functionality for construction and allocating -"split-modality" batches as described in the LLaVa paper; this makes sure that a given device/batch is either entirely -(vision, language) or (language-only) data, which leads to sizeable efficiency gains. -""" - -import math -from typing import Iterator, List, Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import Dataset, Sampler - - -# High-Fidelity Bitwise Reproduction of the LLaVa Codebase Sampler Strategy + Per-Rank Allocation Scheme (following -# the default batching behavior of HF's Trainer Class --> derived from `accelerate`). -# -# =>> Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L60 -# =>> Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L603 -class SplitModalitySampler(Sampler): - def __init__( - self, - dataset: Dataset, - modality_lengths: List[Tuple[bool, int]], - global_batch_size: int, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - seed: int = 0, - drop_last: bool = False, - ) -> None: - super().__init__() - self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size() - self.rank = rank if rank is not None else dist.get_rank() - self.seed, self.epoch = seed, 0 - - # Custom Parameters - self.dataset, self.modality_lengths, self.drop_last = dataset, modality_lengths, drop_last - self.global_batch_size = global_batch_size - - # For our purposes, `drop_last` is always False! - assert not self.drop_last, "SplitModalitySampler must set `drop_last = False`!" - self.total_size = math.ceil(len(self.dataset) / self.global_batch_size) * self.global_batch_size - self.num_samples = self.total_size // self.num_replicas - - @staticmethod - def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]: - """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank.""" - assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!" - - # Establish initial buckets, capacities, and max number of elements per bucket - n_examples_per_bucket = len(batch_idxs) // n_buckets - bucket_indices = [[] for _ in range(n_buckets)] - bucket_lengths = [0 for _ in range(n_buckets)] - - # Note that `batch_idxs` is already sorted by corresponding length (in descending order) - for idx in batch_idxs: - shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths)) - bucket_indices[shortest_bucket_idx].append(idx) - - # Update `bucket_lengths` --> set length to infinity if at capacity! 
- bucket_lengths[shortest_bucket_idx] += idx2lengths[idx] - if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket: - bucket_lengths[shortest_bucket_idx] = float("inf") - - return bucket_indices - - def get_modality_and_length_grouped_indices(self, generator: torch.Generator) -> List[int]: - """ - Returns a list of indices so that each slice of `global_batch_size` consecutive indices corresponds to elements - of the same modality with each sub-sequence of `per_replica_batch_size` (the batch size each unique device sees - during distributed training) is roughly grouped by sequence length (for training efficiency). - """ - multimodal_indices, multimodal_lengths = zip( - *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal] - ) - - # Handle Special Case --> no "unimodal" inputs - unimodal_split = [ - (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal - ] - if len(unimodal_split) == 0: - unimodal_indices, unimodal_lengths = [], [] - else: - unimodal_indices, unimodal_lengths = zip(*unimodal_split) - - # Create a permutation of indices for each of the multimodal and unimodal data - mm_shuffled_idxs = torch.randperm(len(multimodal_indices), generator=generator) - uni_shuffled_idxs = torch.randperm(len(unimodal_indices), generator=generator) - - # We're going to be running sorting/grouping relative to `self.global_batch_size` and `self.num_replicas` - g_bsz = self.global_batch_size - - # Break each of the permutations into batches of length `global_batch_size` - mm_batch_idxs = [mm_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(mm_shuffled_idxs), g_bsz)] - uni_batch_idxs = [uni_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(uni_shuffled_idxs), g_bsz)] - - # If "last" batch is not of length `g_bsz` --> PAD by stealing indices from the first batch! - if len(mm_batch_idxs[-1]) < g_bsz: - n_missing = g_bsz - len(mm_batch_idxs[-1]) - mm_batch_idxs[-1].extend(mm_batch_idxs[0][:n_missing]) - - if len(uni_batch_idxs) > 0 and len(uni_batch_idxs[-1]) < g_bsz: - n_missing = g_bsz - len(uni_batch_idxs[-1]) - uni_batch_idxs[-1].extend(uni_batch_idxs[0][:n_missing]) - - # Now we're going to sort each batch by length --> this will aid in grouping by length by rank (efficiency!) - mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs] - uni_sorted_batch_idxs = [sorted(b, key=lambda i: unimodal_lengths[i], reverse=True) for b in uni_batch_idxs] - - # IMPORTANT :: At this point, for each modality, we have a list of "batches" (made up of indices) where indices - # are sorted by example sequence length *within* each batch. To make this more concrete, consider the following: - # => World Size (`num_replicas`) = 2 - # => Global Batch Size (`g_bsz`) = 4 - # => `multimodal_indices` = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - # `multimodal_lengths` = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17] - # - # At this point in the code, `mm_sorted_batch_idxs` might then look like the following (length in parenthesis): - # => `mm_sorted_batch_idxs`: [ - # [4 (91), 3 (21), 0 (20), 5 (18)] => Batch 1 - # [6 (89), 9 (88), 7 (19), 11 (17)] => Batch 2 - # [8 (93), 10 (92), 1 (90), 2 (21)] => Batch 3 - # ] - # - # In practice: `g_bsz` is large (= 128), and for contiguous mini-batch "slices", length variance is low. 
- - # PROBLEM :: We want to split these "global batches" into equal-sized pieces, so that each "replica" (GPU) - # sees a "mini-batch" of roughly the same sequence lengths; this is super useful for efficient training. - - # HOWEVER :: The default "access pattern" for splitting a large batch into mini-batches by a DistributedSampler - # is akin to a "take every k" where `k` is equal to the number of replicas (GPUs) you're training on. Or, in - # Python notation --> `rank_k_indices = flatten(mm_sorted_batch_idxs)[k::num_replicas]. - # - # Naively translating this our example means each GPU (in our world of 2 total) sees the following indices - # (grouped by "mini-batch" = `g_bsz / num_replicas` = 2 for convenience): - # => `rank_0_indices`: [ [4 (91), 0 (20)] =>> [6 (89), 7 (19)] =>> [8 (93), 1 (90)] ] - # => `rank_1_indices`: [ [3 (21), 5 (18)] =>> [9 (88), 11 (17)] =>> [10 (92), 2 (21)] ] - # - # We get lucky sometimes, but for the most part, each "mini-batch" has VASTLY DIFFERENT lengths! Bad! - - # FIX :: If we "undo" the access pattern with the following code and re-arrange the way we allocate batches - # inside the __iter__ method below, we can allocate indices appropriately. Running the following code gives us - # the following indices (grouped by "mini-batch" again for convenience): - # => `rank_0_indices`: [ [4 (91), 3 (21)] =>> [6 (89), 9 (88)] =>> [8 (93), 10 (92)] ] - # => `rank_1_indices`: [ [5 (18), 0 (20)] =>> [11 (17), 7 (19)] =>> [2 (21), 1 (90)] ] - # - # Much better! As `g_bsz` and `dataset` grow, we're more often than not getting *decent* groupings! - mm_length_bucketed_idxs = [ - self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs - ] - uni_length_bucketed_idxs = [ - self.reindex_batch(batch, unimodal_lengths, self.num_replicas) for batch in uni_sorted_batch_idxs - ] - - # Note :: Because of the initial `randperm` --> we're indexing both sets from 0 (we're clobbering the range) - # => Flatten indices --> index into original `{modality}_indices` then re-batch! - mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket] - mm_reindexed = [multimodal_indices[idx] for idx in mm_output_idxs] - mm_batches = [mm_reindexed[i : i + g_bsz] for i in range(0, len(mm_reindexed), g_bsz)] - - uni_output_idxs = [idx for batch in uni_length_bucketed_idxs for bucket in batch for idx in bucket] - uni_reindexed = [unimodal_indices[idx] for idx in uni_output_idxs] - uni_batches = [uni_reindexed[i : i + g_bsz] for i in range(0, len(uni_reindexed), g_bsz)] - - # Finally, randomly permute the multimodal & unimodal batches, merging into a single stream of indices - merged_batches = mm_batches + uni_batches - merge_idxs = torch.randperm(len(merged_batches), generator=generator) - all_batches = [merged_batches[idx] for idx in merge_idxs] - - # [Quality of Life] Shift "max length" batch to index 0 --> if we OOM, it happens immediately! 
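# --- Toy numeric sketch of the "FIX" above, reusing the same example indices: each rank takes
# --- contiguous `per_replica_batch_size` rows of every global batch instead of a take-every-k split
# --- (this is what `__iter__` below implements with `reshape` + strided row indexing).
import torch

global_batch_size, num_replicas = 4, 2
per_replica_batch_size = global_batch_size // num_replicas

flat_indices = torch.tensor([4, 3, 5, 0, 6, 9, 11, 7, 8, 10, 2, 1])  # length-sorted within each batch
per_replica_rows = flat_indices.reshape(-1, per_replica_batch_size)
rank_0 = per_replica_rows[0::num_replicas].flatten().tolist()  # [4, 3, 6, 9, 8, 10]
rank_1 = per_replica_rows[1::num_replicas].flatten().tolist()  # [5, 0, 11, 7, 2, 1]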
- all_lengths = [length + ((_n_patches := 24 * 24) if is_mm else 0) for is_mm, length in self.modality_lengths] - all_batches_max_lengths = [] - for batch in all_batches: - all_batches_max_lengths.append(max([all_lengths[idx] for idx in batch])) - - # Identify Batch with "max length" --> Swap into Index 0 - longest_batch_idx = np.argmax(all_batches_max_lengths) - all_batches[0], all_batches[longest_batch_idx] = all_batches[longest_batch_idx], all_batches[0] - - # Flatten & Return all Indices - indices = [idx for batch in all_batches for idx in batch] - return indices - - def __iter__(self) -> Iterator: - """Deterministically shuffle, then split indices by modality and length.""" - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = self.get_modality_and_length_grouped_indices(g) - assert len(set(indices)) == len(self.modality_lengths) == len(self.dataset), "Oops!" - assert (len(indices) % self.global_batch_size == 0) and (len(indices) % self.num_replicas) == 0, "Oops" - - # Note :: We compute per-replica batch size as a function of `global_batch` and `num_replicas` to ensure that - # gradient accumulation doesn't affect what indices are assigned a given rank. - per_replica_batch_size = self.global_batch_size // self.num_replicas - - # Tensorize & Unravel --> rather than yielding via a `take_every` --> we want to partition a global batch - # across replicas by assigning each a contiguous sub-sequence. - indices_t = torch.as_tensor(indices) - per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size) - replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas] - - replica_indices = replica_indices_t.flatten().tolist() - return iter(replica_indices) - - def __len__(self) -> int: - return self.num_samples - - def set_epoch(self, epoch: int) -> None: - """To be called *between* epochs, prior to DataLoader instantiation; ensures random order across epochs.""" - self.epoch = epoch diff --git a/capvector-oft/prismatic/util/data_utils.py b/capvector-oft/prismatic/util/data_utils.py deleted file mode 100644 index 1abdca892b115cc837de316681048bd766eb7b81..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/util/data_utils.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -data_utils.py - -General utilities and classes for facilitating data loading and collation. 
-""" - -from dataclasses import dataclass -from typing import Callable, Dict, Sequence, Tuple - -import numpy as np -import torch -from torch.nn.utils.rnn import pad_sequence - -# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) -IGNORE_INDEX = -100 - - -def tree_map(fn: Callable, tree: dict) -> dict: - """Maps a function over a nested dictionary.""" - return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} - - -def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict: - """Maps a function over a nested dictionary.""" - return { - k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items() - } - - -@dataclass -class PaddedCollatorForLanguageModeling: - model_max_length: int - pad_token_id: int - default_image_resolution: Tuple[int, int, int] - padding_side: str = "right" - pixel_values_dtype: torch.dtype = torch.float32 - - def __post_init__(self) -> None: - self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype) - - def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) - pixel_values = [instance["pixel_values"] for instance in instances] - - # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!) - # => Handle padding via RNN Utils => `pad_sequence` - input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) - labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) - - # Truncate (if necessary) - input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] - - # Get `attention_mask` by checking for `pad_token_id` - attention_mask = input_ids.ne(self.pad_token_id) - - # === Handle "unimodal" (language-only) vs. 
"multimodal" === - - # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily - multimodal_indices = torch.tensor( - [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long - ) - - # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None - if len(multimodal_indices) == 0: - pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))]) - elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor): - pixel_values = torch.stack( - [ - pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values - for idx in range(len(input_ids)) - ] - ) - elif isinstance(pv_example, dict): - pixel_values = { - k: torch.stack( - [ - pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values - for idx in range(len(input_ids)) - ] - ) - for k in pv_example - } - else: - raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") - - return dict( - pixel_values=pixel_values, - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - multimodal_indices=multimodal_indices, - ) - - -@dataclass -class PaddedCollatorForActionPrediction: - model_max_length: int - pad_token_id: int - padding_side: str = "right" - pixel_values_dtype: torch.dtype = torch.float32 - - def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) - pixel_values = [instance["pixel_values"] for instance in instances] - if "dataset_name" in instances[0]: - dataset_names = [instance["dataset_name"] for instance in instances] - else: - dataset_names = None - - # For now, we only support Tokenizers with `padding_side = "right"` during training - # => Handle padding via RNN Utils => `pad_sequence` - assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`" - input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) - labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) - - # Truncate (if necessary) - input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] - - # Get `attention_mask` by checking for `pad_token_id` - attention_mask = input_ids.ne(self.pad_token_id) - - # [Contract] For VLA Training =>> No "Unimodal" Data! - assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!" 
- - # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor] - if isinstance(pixel_values[0], torch.Tensor): - if "pixel_values_wrist" in instances[0]: - pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances] - pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1) - else: - pixel_values = torch.stack(pixel_values) - else: - raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") - - # Stack all actions - actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances] - actions = torch.stack(actions) - - # Stack proprio - if "proprio" in instances[0]: - proprio = [instance["proprio"] for instance in instances] - proprio = torch.Tensor(np.squeeze(np.stack(proprio))) - else: - proprio = None - - output = dict( - pixel_values=pixel_values, - proprio=proprio, - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - actions=actions, - ) - if dataset_names is not None: - output["dataset_names"] = dataset_names - return output diff --git a/capvector-oft/prismatic/util/nn_utils.py b/capvector-oft/prismatic/util/nn_utils.py deleted file mode 100644 index cb62c6be7d736aef9057bb26a4c2717af830f2f4..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/util/nn_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -nn_utils.py - -Utility functions and PyTorch submodule definitions. -""" - -import torch -import torch.nn as nn - - -# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === -class LinearProjector(nn.Module): - def __init__(self, vision_dim: int, llm_dim: int) -> None: - super().__init__() - self.projector = nn.Linear(vision_dim, llm_dim, bias=True) - - def forward(self, img_patches: torch.Tensor) -> torch.Tensor: - return self.projector(img_patches) - - -class MLPProjector(nn.Module): - def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: - super().__init__() - if mlp_type == "gelu-mlp": - self.projector = nn.Sequential( - nn.Linear(vision_dim, llm_dim, bias=True), - nn.GELU(), - nn.Linear(llm_dim, llm_dim, bias=True), - ) - else: - raise ValueError(f"Projector with `{mlp_type = }` is not supported!") - - def forward(self, img_patches: torch.Tensor) -> torch.Tensor: - return self.projector(img_patches) - - -class FusedMLPProjector(nn.Module): - def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: - super().__init__() - self.initial_projection_dim = fused_vision_dim * 4 - if mlp_type == "fused-gelu-mlp": - self.projector = nn.Sequential( - nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), - nn.GELU(), - nn.Linear(self.initial_projection_dim, llm_dim, bias=True), - nn.GELU(), - nn.Linear(llm_dim, llm_dim, bias=True), - ) - else: - raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") - - def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: - return self.projector(fused_img_patches) diff --git a/capvector-oft/prismatic/util/torch_utils.py b/capvector-oft/prismatic/util/torch_utils.py deleted file mode 100644 index ddef43b290bc55825a687e9ddd69d2ce532ab7d4..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/util/torch_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -torch_utils.py - -General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch. 
- -Random `set_global_seed` functionality is taken directly from PyTorch-Lighting: - > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py - -This is pretty important to get right if we're every randomly generating our masks (or prefix dropout) inside our -Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime -we inject randomness from non-PyTorch sources (e.g., numpy, random)! - > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ - -Terminology - -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous! - -> Rank :: Integer index of current process in the total world size - -> Local Rank :: Local index on given node in [0, Devices per Node] -""" - -import os -import random -from typing import Callable, Optional - -import numpy as np -import torch - -# === Randomness === - - -def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]: - """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`""" - assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!" - - # Set Seed as an Environment Variable - os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - return worker_init_function if get_worker_init_fn else None - - -def worker_init_function(worker_id: int) -> None: - """ - Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: - > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 - - Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that - you can run iterative splitting on to get new (predictable) randomness. - - :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. - """ - # Get current `rank` (if running distributed) and `process_seed` - global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed() - - # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: - # > https://pytorch.org/docs/stable/data.html#data-loading-randomness - base_seed = process_seed - worker_id - - # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... - seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) - - # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! 
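# --- Hedged usage sketch for the seeding utilities in this file: seed every RNG once at process start,
# --- then pass the returned `worker_init_fn` to each DataLoader so numpy/random state differs per
# --- worker and per rank (assumes `set_global_seed` above and a `LOCAL_RANK` env var, as under torchrun).
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

os.environ.setdefault("LOCAL_RANK", "0")
worker_init_fn = set_global_seed(7, get_worker_init_fn=True)
toy_loader = DataLoader(
    TensorDataset(torch.arange(8)), batch_size=4, num_workers=2, worker_init_fn=worker_init_fn
)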
- np.random.seed(seed_seq.generate_state(4)) - - # Spawn distinct child sequences for PyTorch (reseed) and stdlib random - torch_seed_seq, random_seed_seq = seed_seq.spawn(2) - - # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 - torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) - - # Use 128 Bits for `random`, but express as integer instead of as an array - random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() - random.seed(random_seed) - - -# === BFloat16 Support === - - -def check_bloat16_supported() -> bool: - try: - import packaging.version - import torch.cuda.nccl as nccl - import torch.distributed as dist - - return ( - (torch.version.cuda is not None) - and torch.cuda.is_bf16_supported() - and (packaging.version.parse(torch.version.cuda).release >= (11, 0)) - and dist.is_nccl_available() - and (nccl.version() >= (2, 10)) - ) - - except Exception: - return False diff --git a/capvector-oft/prismatic/vla/__init__.py b/capvector-oft/prismatic/vla/__init__.py deleted file mode 100644 index bd2de0ce872181a9c7e3e9bfafd30e90381cd3e6..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .materialize import get_vla_dataset_and_collator diff --git a/capvector-oft/prismatic/vla/action_tokenizer.py b/capvector-oft/prismatic/vla/action_tokenizer.py deleted file mode 100644 index bbb6ffa4ae191dade3d940c4363f02c92a270b36..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/action_tokenizer.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -action_tokenizer.py - -Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions. -""" - -from typing import List, Union - -import numpy as np -from transformers import PreTrainedTokenizerBase - - -class ActionTokenizer: - def __init__( - self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1 - ) -> None: - """ - Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens. - - NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens* - appear at the end of the vocabulary! - - :param tokenizer: Base LLM/VLM tokenizer to extend. - :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy. - :param min_action: Minimum action value (for clipping, setting lower bound on bin interval). - :param max_action: Maximum action value (for clipping, setting upper bound on bin interval). - """ - self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action - - # Create Uniform Bins + Compute Bin Centers - self.bins = np.linspace(min_action, max_action, self.n_bins) - self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0 - - # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)` - # =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary! - self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1)) - - def __call__(self, action: np.ndarray) -> Union[str, List[str]]: - """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:]).""" - action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action)) - discretized_action = np.digitize(action, self.bins) - - # Handle single element vs. 
batch - if len(discretized_action.shape) == 1: - return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action)) - else: - return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist()) - - def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray: - """ - Returns continuous actions for discrete action token IDs. - - NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the - digitization returns bin indices between [1, # bins], inclusive, when there are actually only - (# bins - 1) bin intervals. - - Therefore, if the digitization returns the last possible index, we map this to the last bin interval. - - EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns - indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There - is still one index (i==255) that would cause an out-of-bounds error if used to index into - self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of - the last bin center. We implement this simply via clipping between [0, 255 - 1]. - """ - discretized_actions = self.tokenizer.vocab_size - action_token_ids - discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) - - return self.bin_centers[discretized_actions] - - @property - def vocab_size(self) -> int: - return self.n_bins diff --git a/capvector-oft/prismatic/vla/constants.py b/capvector-oft/prismatic/vla/constants.py deleted file mode 100644 index 81e8f9941218d81cf76cd3ce4fe07182813ed442..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/constants.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Important constants for VLA training and evaluation. - -Attempts to automatically identify the correct constants to set based on the Python command used to launch -training or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants. -""" -import sys -from enum import Enum - -# Llama 2 token constants -IGNORE_INDEX = -100 -ACTION_TOKEN_BEGIN_IDX = 31743 -STOP_INDEX = 2 # '' - - -# Defines supported normalization schemes for action and proprioceptive state. 
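# Roughly (see `normalize_action_and_proprio` in the RLDS data utils for the exact clipping behavior):
#   NORMAL     : a_norm = (a - mean) / std
#   BOUNDS     : a_norm = 2 * (a - min) / (max - min) - 1
#   BOUNDS_Q99 : a_norm = 2 * (a - q01) / (q99 - q01) - 1, using per-dimension 1st / 99th quantiles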
-class NormalizationType(str, Enum): - # fmt: off - NORMAL = "normal" # Normalize to Mean = 0, Stdev = 1 - BOUNDS = "bounds" # Normalize to Interval = [-1, 1] - BOUNDS_Q99 = "bounds_q99" # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1] - # fmt: on - - -# Define constants for each robot platform -LIBERO_CONSTANTS = { - "NUM_ACTIONS_CHUNK": 8, - "ACTION_DIM": 7, - "PROPRIO_DIM": 8, - "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, -} - -ALOHA_CONSTANTS = { - "NUM_ACTIONS_CHUNK": 25, - "ACTION_DIM": 14, - "PROPRIO_DIM": 14, - "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS, -} - -BRIDGE_CONSTANTS = { - "NUM_ACTIONS_CHUNK": 5, - "ACTION_DIM": 7, - "PROPRIO_DIM": 7, - "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, -} - - -# Function to detect robot platform from command line arguments -def detect_robot_platform(): - cmd_args = " ".join(sys.argv).lower() - - if "libero" in cmd_args: - return "LIBERO" - elif "aloha" in cmd_args: - return "ALOHA" - elif "bridge" in cmd_args: - return "BRIDGE" - else: - # Default to LIBERO if unclear - return "LIBERO" - - -# Determine which robot platform to use -ROBOT_PLATFORM = detect_robot_platform() - -# Set the appropriate constants based on the detected platform -if ROBOT_PLATFORM == "LIBERO": - constants = LIBERO_CONSTANTS -elif ROBOT_PLATFORM == "ALOHA": - constants = ALOHA_CONSTANTS -elif ROBOT_PLATFORM == "BRIDGE": - constants = BRIDGE_CONSTANTS - -# Assign constants to global variables -NUM_ACTIONS_CHUNK = constants["NUM_ACTIONS_CHUNK"] -ACTION_DIM = constants["ACTION_DIM"] -PROPRIO_DIM = constants["PROPRIO_DIM"] -ACTION_PROPRIO_NORMALIZATION_TYPE = constants["ACTION_PROPRIO_NORMALIZATION_TYPE"] - -# Print which robot platform constants are being used (for debugging) -print(f"Using {ROBOT_PLATFORM} constants:") -print(f" NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}") -print(f" ACTION_DIM = {ACTION_DIM}") -print(f" PROPRIO_DIM = {PROPRIO_DIM}") -print(f" ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}") -print("If needed, manually set the correct constants in `prismatic/vla/constants.py`!") diff --git a/capvector-oft/prismatic/vla/datasets/__init__.py b/capvector-oft/prismatic/vla/datasets/__init__.py deleted file mode 100644 index 343cadd8964350abbd7fe75e6d1286b6756db795..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset diff --git a/capvector-oft/prismatic/vla/datasets/datasets.py b/capvector-oft/prismatic/vla/datasets/datasets.py deleted file mode 100644 index ddc3deaac3c180a5d862884f09010358de4b149c..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/datasets.py +++ /dev/null @@ -1,261 +0,0 @@ -""" -datasets.py - -Lightweight PyTorch Dataset Definition for wrapping RLDS TFDS Pipeline; just defines transform from RLDS default -format to OpenVLA, IterableDataset shim. 
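
Illustrative usage (the path and dataset mix below are examples, not the actual training entrypoint):

    batch_transform = RLDSBatchTransform(action_tokenizer, base_tokenizer, image_transform, prompt_builder_fn)
    vla_dataset = RLDSDataset(Path("/data/rlds"), "bridge_dataset", batch_transform, resize_resolution=(224, 224))
    for batch in vla_dataset:
        ...  # each `batch` is the dict produced by RLDSBatchTransform.__call__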
-""" - -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, Tuple, Type - -import numpy as np -import torch -from PIL import Image -from torch.utils.data import Dataset, IterableDataset -from transformers import PreTrainedTokenizerBase - -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.models.backbones.vision import ImageTransform -from prismatic.util.data_utils import tree_map -from prismatic.vla.action_tokenizer import ActionTokenizer -from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX -from prismatic.vla.datasets.rlds import make_interleaved_dataset, make_single_dataset -from prismatic.vla.datasets.rlds.oxe import OXE_NAMED_MIXTURES, get_oxe_dataset_kwargs_and_weights - -@dataclass -class RLDSBatchTransform: - action_tokenizer: ActionTokenizer - base_tokenizer: PreTrainedTokenizerBase - image_transform: ImageTransform - prompt_builder_fn: Type[PromptBuilder] - predict_stop_token: bool = True - use_wrist_image: bool = False - use_proprio: bool = False - - def __call__(self, rlds_batch: Dict[str, Any]) -> Dict[str, Any]: - """Converts a RLDS batch to the format expected by the OpenVLA collator/models.""" - dataset_name, current_action = rlds_batch["dataset_name"], rlds_batch["action"][0] - img = Image.fromarray(rlds_batch["observation"]["image_primary"][0]) - lang = rlds_batch["task"]["language_instruction"].decode().lower() - actions = rlds_batch["action"] - - # Construct Chat-based Prompt =>> Input is default query + language instruction, output are the action tokens - prompt_builder = self.prompt_builder_fn("openvla") - - # Get future action chunk - future_actions = rlds_batch["action"][1:] - future_actions_string = ''.join(self.action_tokenizer(future_actions)) - - # Get action chunk string - current_action_string = self.action_tokenizer(current_action) - action_chunk_string = current_action_string + future_actions_string - action_chunk_len = len(action_chunk_string) - - conversation = [ - {"from": "human", "value": f"What action should the robot take to {lang}?"}, - {"from": "gpt", "value": action_chunk_string}, - ] - for turn in conversation: - prompt_builder.add_turn(turn["from"], turn["value"]) - - # Tokenize (w/ `base_tokenizer`) - input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids - labels = list(input_ids) - - # Tensorize =>> Run Image Transform to get `pixel_values` =>> Return - # =>> IMPORTANT :: IF WE'RE USING HF LLM.forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! - input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) - pixel_values = self.image_transform(img) - - # [CRITICAL] We do not want to take the loss for anything but the predicted action tokens! 
- labels[: -(action_chunk_len + 1)] = IGNORE_INDEX - if not self.predict_stop_token: - labels[-1] = IGNORE_INDEX - - return_dict = dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels, dataset_name=dataset_name, actions=actions) - - # Add additional inputs - if self.use_wrist_image: - all_wrist_pixels = [] - for k in rlds_batch["observation"].keys(): - if "wrist" in k: - img_wrist = Image.fromarray(rlds_batch["observation"][k][0]) - pixel_values_wrist = self.image_transform(img_wrist) - all_wrist_pixels.append(pixel_values_wrist) - return_dict["pixel_values_wrist"] = torch.cat(all_wrist_pixels, dim=0) - if self.use_proprio and "proprio" in rlds_batch["observation"]: - proprio = rlds_batch["observation"]["proprio"] - return_dict["proprio"] = proprio - - return return_dict - - -class RLDSDataset(IterableDataset): - def __init__( - self, - data_root_dir: Path, - data_mix: str, - batch_transform: RLDSBatchTransform, - resize_resolution: Tuple[int, int], - shuffle_buffer_size: int = 256_000, - train: bool = True, - image_aug: bool = False, - ) -> None: - """Lightweight wrapper around RLDS TFDS Pipeline for use with PyTorch/OpenVLA Data Loaders.""" - self.data_root_dir, self.data_mix, self.batch_transform = data_root_dir, data_mix, batch_transform - - # Configure RLDS Dataset(s) - if self.data_mix in OXE_NAMED_MIXTURES: - mixture_spec = OXE_NAMED_MIXTURES[self.data_mix] - else: - # Assume that passed "mixture" name is actually a single dataset -- create single-dataset "mix" - mixture_spec = [(self.data_mix, 1.0)] - - # fmt: off - if "aloha" in self.data_mix: - load_camera_views = ("primary", "left_wrist", "right_wrist") - else: - load_camera_views = ("primary", "wrist") - - per_dataset_kwargs, weights = get_oxe_dataset_kwargs_and_weights( - self.data_root_dir, - mixture_spec, - load_camera_views=load_camera_views, - load_depth=False, - load_proprio=True, - load_language=True, - action_proprio_normalization_type=ACTION_PROPRIO_NORMALIZATION_TYPE, - ) - rlds_config = dict( - traj_transform_kwargs=dict( - window_size=1, # If we wanted to feed / predict more than one step - future_action_window_size=NUM_ACTIONS_CHUNK-1, # For action chunking - skip_unlabeled=True, # Skip trajectories without language labels - goal_relabeling_strategy="uniform", # Goals are currently unused - ), - frame_transform_kwargs=dict( - resize_size=resize_resolution, - num_parallel_calls=16, # For CPU-intensive ops (decoding, resizing, etc.) 
- ), - dataset_kwargs_list=per_dataset_kwargs, - shuffle_buffer_size=shuffle_buffer_size, - sample_weights=weights, - balance_weights=True, - traj_transform_threads=len(mixture_spec), - traj_read_threads=len(mixture_spec), - train=train, - ) - - # If applicable, enable image augmentations - if image_aug: - rlds_config["frame_transform_kwargs"].update({"image_augment_kwargs" : dict( - random_resized_crop=dict(scale=[0.9, 0.9], ratio=[1.0, 1.0]), - random_brightness=[0.2], - random_contrast=[0.8, 1.2], - random_saturation=[0.8, 1.2], - random_hue=[0.05], - augment_order=[ - "random_resized_crop", - "random_brightness", - "random_contrast", - "random_saturation", - "random_hue", - ], - )}), - # fmt: on - - # Initialize RLDS Dataset - self.dataset, self.dataset_length, self.dataset_statistics = self.make_dataset(rlds_config) - - def make_dataset(self, rlds_config): - return make_interleaved_dataset(**rlds_config) - - def __iter__(self) -> Dict[str, Any]: - for rlds_batch in self.dataset.as_numpy_iterator(): - yield self.batch_transform(rlds_batch) - - def __len__(self) -> int: - return self.dataset_length - - # === Explicitly Unused === - def __getitem__(self, idx: int) -> None: - raise NotImplementedError("IterableDataset does not implement map-style __getitem__; see __iter__ instead!") - - -class EpisodicRLDSDataset(RLDSDataset): - """Returns full episodes as list of steps instead of individual transitions (useful for visualizations).""" - - def make_dataset(self, rlds_config): - per_dataset_kwargs = rlds_config["dataset_kwargs_list"] - assert len(per_dataset_kwargs) == 1, "Only support single-dataset `mixes` for episodic datasets." - - return make_single_dataset( - per_dataset_kwargs[0], - train=rlds_config["train"], - traj_transform_kwargs=rlds_config["traj_transform_kwargs"], - frame_transform_kwargs=rlds_config["frame_transform_kwargs"], - ) - - def __iter__(self) -> Dict[str, Any]: - for rlds_batch in self.dataset.as_numpy_iterator(): - out = [ - self.batch_transform(tree_map(lambda x: x[i], rlds_batch)) # noqa: B023 - for i in range(rlds_batch["action"].shape[0]) - ] - yield out - - -class DummyDataset(Dataset): - def __init__( - self, - action_tokenizer: ActionTokenizer, - base_tokenizer: PreTrainedTokenizerBase, - image_transform: ImageTransform, - prompt_builder_fn: Type[PromptBuilder], - ) -> None: - self.action_tokenizer = action_tokenizer - self.base_tokenizer = base_tokenizer - self.image_transform = image_transform - self.prompt_builder_fn = prompt_builder_fn - - # Note =>> We expect the dataset to store statistics for action de-normalization. Specifically, we store the - # per-dimension 1st and 99th action quantile. The values below correspond to "no normalization" for simplicity. - self.dataset_statistics = { - "dummy_dataset": { - "action": {"q01": np.zeros((7,), dtype=np.float32), "q99": np.ones((7,), dtype=np.float32)} - } - } - - def __len__(self): - # TODO =>> Replace with number of elements in your dataset! 
- return 10000 - - def __getitem__(self, idx): - # TODO =>> Load image, action and instruction from disk -- we use dummy values - image = Image.fromarray(np.asarray(np.random.rand(224, 224, 3) * 255.0, dtype=np.uint8)) - action = np.asarray(np.random.rand(7), dtype=np.float32) - instruction = "do something spectacular" - - # Add instruction to VLA prompt - prompt_builder = self.prompt_builder_fn("openvla") - conversation = [ - {"from": "human", "value": f"What action should the robot take to {instruction}?"}, - {"from": "gpt", "value": self.action_tokenizer(action)}, - ] - for turn in conversation: - prompt_builder.add_turn(turn["from"], turn["value"]) - - # Tokenize (w/ `base_tokenizer`) - input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids - labels = list(input_ids) - - # Tensorize =>> Run Image Transform to get `pixel_values` =>> Return - # =>> IMPORTANT :: IF WE'RE USING HF .forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! - input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) - pixel_values = self.image_transform(image) - - # [CRITICAL] We do not want to take the loss for anything but the predicted action tokens! - labels[: -(len(action) + 1)] = IGNORE_INDEX - - return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) diff --git a/capvector-oft/prismatic/vla/datasets/rlds/__init__.py b/capvector-oft/prismatic/vla/datasets/rlds/__init__.py deleted file mode 100644 index 4b260fa5b7326012e85e54be203fa4932f35783d..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .dataset import make_interleaved_dataset, make_single_dataset diff --git a/capvector-oft/prismatic/vla/datasets/rlds/dataset.py b/capvector-oft/prismatic/vla/datasets/rlds/dataset.py deleted file mode 100644 index 9ff0424d071c8888c8f75e54b2eeaedc7b9bd9c5..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/dataset.py +++ /dev/null @@ -1,585 +0,0 @@ -""" -dataset.py - -Core interface script for configuring and initializing RLDS datasets. 
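
A rough sketch of the main entrypoint (kwargs shown are illustrative; see `RLDSDataset` in
`prismatic/vla/datasets/datasets.py` for the real call site):

    dataset, dataset_len, dataset_statistics = make_interleaved_dataset(
        dataset_kwargs_list=per_dataset_kwargs,
        sample_weights=weights,
        train=True,
        shuffle_buffer_size=256_000,
        traj_transform_kwargs=dict(window_size=1, future_action_window_size=7),
        frame_transform_kwargs=dict(resize_size=(224, 224)),
    )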
-""" - -import copy -import inspect -import json -from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union - -import dlimp as dl -import numpy as np -import tensorflow as tf -import tensorflow_datasets as tfds - -from prismatic.overwatch import initialize_overwatch -from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX -from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms -from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation -from prismatic.vla.datasets.rlds.utils.data_utils import ( - allocate_threads, - get_dataset_statistics, - normalize_action_and_proprio, - pprint_data_mixture, - tree_map, -) - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch) -tf.config.set_visible_devices([], "GPU") - - -# ruff: noqa: B006 -def make_dataset_from_rlds( - name: str, - data_dir: str, - *, - train: bool, - standardize_fn: Optional[Callable[[dict], dict]] = None, - shuffle: bool = True, - image_obs_keys: Dict[str, Optional[str]] = {}, - depth_obs_keys: Dict[str, Optional[str]] = {}, - state_obs_keys: List[Optional[str]] = (), - language_key: Optional[str] = None, - action_proprio_normalization_type: ACTION_PROPRIO_NORMALIZATION_TYPE, - dataset_statistics: Optional[Union[dict, str]] = None, - absolute_action_mask: Optional[List[bool]] = None, - action_normalization_mask: Optional[List[bool]] = None, - num_parallel_reads: int = tf.data.AUTOTUNE, - num_parallel_calls: int = tf.data.AUTOTUNE, -) -> Tuple[dl.DLataset, dict]: - """ - This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized - format. Yields a dataset of trajectories. Does not include CPU-intensive operations. - - If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory - into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a - dictionary containing some number of additional keys, which will be extracted into an even more standardized format - according to the "*_obs_keys" arguments. - - The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an - old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called - "workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then - the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and - "image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and - "image_wrist" corresponds to "wrist". - - Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will - be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each - None entry. - - The dataset will also include a "task" dict. If `language_key` is provided, then the "task" dict will contain the - key "language_instruction", extracted from `traj[language_key]`. - - Args: - name (str): The name of the RLDS dataset (usually "name" or "name:version"). - data_dir (str): The path to the data directory. 
- train (bool): Whether to use the training or validation split. - shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one - file usually contains many trajectories)! - standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first - thing applied to each trajectory. - image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the - "observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`. - If a value of `old` is None, inserts a padding image instead (empty string). - depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be - prefixed with "depth_" instead of "image_". - state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the - "observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry. - language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction", - extracted from `traj[language_key]`. - action_proprio_normalization_type (str, optional): The type of normalization to perform on the action, - proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]). - dataset_statistics: (dict|str, optional): dict (or path to JSON file) that contains dataset statistics - for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and - "std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max" - keys. May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for - `make_interleaved_dataset`). If not provided, the statistics will be computed on the fly. - absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be - relative. This is important for when `future_action_window_size > 0`: actions that are taken - from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used) - need to be made "neutral" to indicate that the task has been completed. For relative actions, - "neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action. - This mask, if provided, indicates which action dimensions are absolute. - action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions - should be normalized. For example, you might not want to normalize the gripper action dimension if - it's always exactly 0 or 1. By default, all action dimensions are normalized. - num_parallel_reads (int): number of parallel read workers. Default to AUTOTUNE. - num_parallel_calls (int): number of parallel calls for traj_map operations. Default to AUTOTUNE. 
- Returns: - Dataset of trajectories where each step has the following fields: - - observation: - - image_{name1, name2, ...} # RGB image observations - - depth_{name1, name2, ...} # depth image observations - - proprio # 1-dimensional array of proprioceptive observations - - timestep # timestep of each frame - - task: - - language_instruction # language instruction, present if `language_key` is provided - - action # action vector - - dataset_name # name of the dataset - """ - REQUIRED_KEYS = {"observation", "action"} - if language_key is not None: - REQUIRED_KEYS.add(language_key) - - def restructure(traj): - # apply a standardization function, if provided - if standardize_fn is not None: - traj = standardize_fn(traj) - - if not all(k in traj for k in REQUIRED_KEYS): - raise ValueError( - f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?" - ) - - # extracts images, depth images and proprio from the "observation" dict - traj_len = tf.shape(traj["action"])[0] - old_obs = traj["observation"] - new_obs = {} - for new, old in image_obs_keys.items(): - if old is None: - new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding - else: - new_obs[f"image_{new}"] = old_obs[old] - - for new, old in depth_obs_keys.items(): - if old is None: - new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding - else: - new_obs[f"depth_{new}"] = old_obs[old] - - if state_obs_keys: - new_obs["proprio"] = tf.concat( - [ - ( - tf.zeros((traj_len, 1), dtype=tf.float32) # padding - if key is None - else tf.cast(old_obs[key], tf.float32) - ) - for key in state_obs_keys - ], - axis=1, - ) - - # add timestep info - new_obs["timestep"] = tf.range(traj_len) - - # extracts `language_key` into the "task" dict - task = {} - if language_key is not None: - if traj[language_key].dtype != tf.string: - raise ValueError( - f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string." - ) - task["language_instruction"] = traj.pop(language_key) - - traj = { - "observation": new_obs, - "task": task, - "action": tf.cast(traj["action"], tf.float32), - "dataset_name": tf.repeat(name, traj_len), - } - - if absolute_action_mask is not None: - if len(absolute_action_mask) != traj["action"].shape[-1]: - raise ValueError( - f"Length of absolute_action_mask ({len(absolute_action_mask)}) " - f"does not match action dimension ({traj['action'].shape[-1]})." 
- ) - traj["absolute_action_mask"] = tf.tile( - tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None], - [traj_len, 1], - ) - - return traj - - builder = tfds.builder(name, data_dir=data_dir) - - # load or compute dataset statistics - if isinstance(dataset_statistics, str): - with tf.io.gfile.GFile(dataset_statistics, "r") as f: - dataset_statistics = json.load(f) - elif dataset_statistics is None: - full_dataset = dl.DLataset.from_rlds( - builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads - ).traj_map(restructure, num_parallel_calls) - # tries to load from cache, otherwise computes on the fly - dataset_statistics = get_dataset_statistics( - full_dataset, - hash_dependencies=( - str(builder.info), - str(state_obs_keys), - inspect.getsource(standardize_fn) if standardize_fn is not None else "", - ), - save_dir=builder.data_dir, - ) - dataset_statistics = tree_map(np.array, dataset_statistics) - - # skip normalization for certain action dimensions - if action_normalization_mask is not None: - if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]: - raise ValueError( - f"Length of skip_normalization_mask ({len(action_normalization_mask)}) " - f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})." - ) - dataset_statistics["action"]["mask"] = np.array(action_normalization_mask) - - # construct the dataset - split = "train" if train else "val" - - dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads) - - dataset = dataset.traj_map(restructure, num_parallel_calls) - dataset = dataset.traj_map( - partial( - normalize_action_and_proprio, - metadata=dataset_statistics, - normalization_type=action_proprio_normalization_type, - ), - num_parallel_calls, - ) - - return dataset, dataset_statistics - - -def apply_trajectory_transforms( - dataset: dl.DLataset, - *, - train: bool, - goal_relabeling_strategy: Optional[str] = None, - goal_relabeling_kwargs: dict = {}, - window_size: int = 1, - future_action_window_size: int = 0, - subsample_length: Optional[int] = None, - skip_unlabeled: bool = False, - max_action: Optional[float] = None, - max_proprio: Optional[float] = None, - task_augment_strategy: Optional[str] = None, - task_augment_kwargs: dict = {}, - num_parallel_calls: int = tf.data.AUTOTUNE, -) -> dl.DLataset: - """ - Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling" - (e.g., filtering, chunking, adding goals, dropping keys). - - Transforms in this function should have the following properties: - - They require access to an entire trajectory (i.e., they cannot be applied frame-wise). - - They are generally not CPU-intensive, mostly involving moving and copying data. - - They do not require decoded images. - - Args: - dataset (dl.DLataset): The dataset to transform. - train (bool): Whether the dataset is for training (affects subsampling). - goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for - no goal relabeling. See `goal_relabeling.py`. - goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function. - window_size (int, optional): The length of the snippets that trajectories are chunked into. - future_action_window_size (int, optional): The number of future actions beyond window_size to include - in the chunked actions. 
- subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to - this length (after goal relabeling and chunking). - skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels. - max_action: (float, optional): If provided, trajectories in which *any* action dimension - of *any* transition has an absolute value larger than this will be skipped. - max_proprio: (float, optional): If provided, trajectories in which *any* proprio dimension - of *any* transition has an absolute value larger than this will be skipped. - task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task - augmentation. See `task_augmentation.py`. - task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation - function. - num_parallel_calls (int, optional): number of parallel calls for map operations. Default to AUTOTUNE. - """ - if skip_unlabeled: - if "language_instruction" not in dataset.element_spec["task"]: - raise ValueError("skip_unlabeled=True but dataset does not have language labels.") - - dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != "")) - - if max_action is not None: - dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action)) - - if max_proprio is not None and "proprio" in dataset.element_spec["observation"]: - dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio)) - - # marks which entires of the observation and task dicts are padding - dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls) - - # updates the "task" dict - if goal_relabeling_strategy is not None: - dataset = dataset.traj_map( - partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs), - num_parallel_calls, - ) - - # must run task augmentation before chunking, in case it changes goal timesteps - if train and task_augment_strategy is not None: - # perform task augmentation (e.g., dropping keys) - dataset = dataset.traj_map( - partial( - getattr(task_augmentation, task_augment_strategy), - **task_augment_kwargs, - ), - num_parallel_calls, - ) - - # chunks observations and actions, giving them a new axis at index 1 of size `window_size` and - # `window_size + future_action_window_size`, respectively - dataset = dataset.traj_map( - partial( - traj_transforms.chunk_act_obs, - window_size=window_size, - future_action_window_size=future_action_window_size, - ), - num_parallel_calls, - ) - - if train and subsample_length is not None: - dataset = dataset.traj_map( - partial(traj_transforms.subsample, subsample_length=subsample_length), - num_parallel_calls, - ) - - return dataset - - -def apply_per_dataset_frame_transforms( - dataset: dl.DLataset, - chunk_filter_fn: Optional[Callable] = None, -): - """ - Optionally applied *per-dataset* transforms that happen at a frame level. - - Args: - chunk_filter_fn (callable, optional): Filter function for chunks. 
- """ - if chunk_filter_fn: - dataset = dataset.filter(chunk_filter_fn) - return dataset - - -def apply_frame_transforms( - dataset: dl.DLataset, - *, - train: bool, - image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {}, - resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {}, - depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {}, - num_parallel_calls: int = tf.data.AUTOTUNE, -) -> dl.DLataset: - """ - Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive, (e.g., - decoding or resizing images). - - Args: - train (bool): Whether the dataset is for training (affects image augmentation). - dataset (dl.DLataset): The dataset to transform. - image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation - function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of - dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys` - in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict - to skip augmentation for all images). - resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to - this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names - determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing - keys (so pass an empty dict to skip resizing for all images). - depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth - images. - num_parallel_calls (int): number of parallel calls for frame_map operations. Default to AUTOTUNE. - """ - - # Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies - # it to the chunked "observation" dict as well as the non-chunked "task" dict - def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict: - frame["task"] = fn(frame["task"]) - frame["observation"] = dl.vmap(fn)(frame["observation"]) - return frame - - # Decode + resize images (and depth images) - dataset = dataset.frame_map( - partial( - apply_obs_transform, - partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size), - ), - num_parallel_calls, - ) - - if train: - # Augment all images with the same seed, skipping padding images - def aug(frame: dict): - seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) - aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs) - return apply_obs_transform(aug_fn, frame) - - dataset = dataset.frame_map(aug, num_parallel_calls) - - return dataset - - -def make_single_dataset( - dataset_kwargs: dict, - *, - train: bool, - traj_transform_kwargs: dict = {}, - frame_transform_kwargs: dict = {}, -) -> dl.DLataset: - """Creates a single dataset from kwargs. Returns a dataset of trajectories. - - Args: - dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific. - train: whether this is a training or validation dataset. - traj_transform_kwargs: kwargs passed to 'apply_trajectory_transforms'. - frame_transform_kwargs: kwargs passed to 'get_frame_transforms'. 
- """ - dataset, dataset_statistics = make_dataset_from_rlds( - **dataset_kwargs, - train=train, - ) - dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train) - dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train) - - # this seems to reduce memory usage without affecting speed - dataset = dataset.with_ram_budget(1) - - # save for later - return dataset, dataset_statistics["num_trajectories"], dataset_statistics - - -# === Core Initializer === -def make_interleaved_dataset( - dataset_kwargs_list: List[Dict], - sample_weights: Optional[List[float]] = None, - *, - train: bool, - shuffle_buffer_size: int, - traj_transform_kwargs: Optional[Dict] = None, - frame_transform_kwargs: Optional[Dict] = None, - batch_size: Optional[int] = None, - balance_weights: bool = False, - traj_transform_threads: Optional[int] = None, - traj_read_threads: Optional[int] = None, -) -> dl.DLataset: - """ - Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames. - - Args: - dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`. - "num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and - `traj_read_threads`, respectively. - sample_weights: sampling weights for each dataset in list. If None, defaults to uniform. - train: whether this is a training or validation dataset. - shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames). - traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is - overridden using `traj_transform_threads`. - frame_transform_kwargs: kwargs passed to `apply_frame_transforms`. - batch_size: batch size, if not provided output is not batched. - balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset. - This makes it so that, if all the sample weights are equal, one full iteration through the interleaved - dataset will correspond to one full iteration through each individual dataset (only in expectation, - since in practice the sampling is random). - traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across - datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset. - traj_read_threads: total number of parallel read workers for trajectory transforms, distributed across - datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset. 
- """ - # Default to uniform sampling (if `sample_weights` is not specified) - if not sample_weights: - sample_weights = [1.0] * len(dataset_kwargs_list) - - if len(sample_weights) != len(dataset_kwargs_list): - raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.") - - # Check valid `traj_transform_kwargs` and `frame_transform_kwargs` - if (traj_transform_kwargs is None) or (frame_transform_kwargs is None): - raise ValueError("Missing `traj_transform_kwargs` and `frame_transform_kwargs`!") - - # Get Dataset Sizes - dataset_sizes, all_dataset_statistics = [], {} - for dataset_kwargs in dataset_kwargs_list: - data_kwargs = copy.deepcopy(dataset_kwargs) - if "dataset_frame_transform_kwargs" in data_kwargs: - data_kwargs.pop("dataset_frame_transform_kwargs") - _, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train) - dataset_sizes.append(dataset_statistics["num_transitions"]) - all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics - - # Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0) - primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0]) - - # Balance and Normalize Weights - if balance_weights: - sample_weights = np.array(sample_weights) * np.array(dataset_sizes) - sample_weights = np.array(sample_weights) / np.sum(sample_weights) - pprint_data_mixture(dataset_kwargs_list, sample_weights) - - # Effective Dataset Length = Number of samples until each dataset has completed at least one epoch - # =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0) - dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max()) - - # Allocate Threads based on Weights - threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights) - reads_per_dataset = allocate_threads(traj_read_threads, sample_weights) - - overwatch.info("Threads per Dataset: %s", threads_per_dataset) - overwatch.info("Reads per Dataset: %s", reads_per_dataset) - - # Construct Datasets - overwatch.info("Constructing datasets...") - datasets = [] - for dataset_kwargs, threads, reads in zip( - dataset_kwargs_list, - threads_per_dataset, - reads_per_dataset, - ): - dataset_frame_transform_kwargs = ( - dataset_kwargs.pop("dataset_frame_transform_kwargs") - if "dataset_frame_transform_kwargs" in dataset_kwargs - else {} - ) - dataset, _ = make_dataset_from_rlds( - **dataset_kwargs, - train=train, - num_parallel_calls=threads, - num_parallel_reads=reads, - dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]], - ) - dataset = apply_trajectory_transforms( - dataset.repeat(), - **traj_transform_kwargs, - num_parallel_calls=threads, - train=train, - ).flatten(num_parallel_calls=threads) - dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs) - datasets.append(dataset) - - # Interleave at the Frame Level - dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights) - - # Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase! - if not train: - dataset = dataset.take(shuffle_buffer_size).cache() - - # Shuffle the Dataset - # =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak! 
- dataset = dataset.shuffle(shuffle_buffer_size) - - # Apply Frame Transforms - overwatch.info("Applying frame transforms on dataset...") - dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train) - - # [Contract] When training VLA Policies, we let the Collator handle Batching! - if batch_size is not None: - dataset = dataset.batch(batch_size) - - # Note =>> Seems to reduce memory usage without affecting speed? - dataset = dataset.with_ram_budget(1) - - # Save for Later - dataset.sample_weights = sample_weights - - return dataset, dataset_len, all_dataset_statistics diff --git a/capvector-oft/prismatic/vla/datasets/rlds/obs_transforms.py b/capvector-oft/prismatic/vla/datasets/rlds/obs_transforms.py deleted file mode 100644 index f5f537040dbf17651dbaf169c038268d8cc2a1ad..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/obs_transforms.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -obs_transforms.py - -Contains observation-level transforms used in the orca data pipeline. - -These transforms operate on the "observation" dictionary, and are applied at a per-frame level. -""" - -from typing import Dict, Tuple, Union - -import dlimp as dl -import tensorflow as tf -from absl import logging - - -# ruff: noqa: B023 -def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict: - """Augments images, skipping padding images.""" - image_names = {key[6:] for key in obs if key.startswith("image_")} - - # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed - # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image - # name to augmentation dict) - if "augment_order" in augment_kwargs: - augment_kwargs = {name: augment_kwargs for name in image_names} - - for i, name in enumerate(image_names): - if name not in augment_kwargs: - continue - kwargs = augment_kwargs[name] - logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") - obs[f"image_{name}"] = tf.cond( - obs["pad_mask_dict"][f"image_{name}"], - lambda: dl.transforms.augment_image( - obs[f"image_{name}"], - **kwargs, - seed=seed + i, # augment each image differently - ), - lambda: obs[f"image_{name}"], # skip padding images - ) - - return obs - - -def decode_and_resize( - obs: Dict, - resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], - depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], -) -> Dict: - """Decodes images and depth images, and then optionally resizes them.""" - image_names = {key[6:] for key in obs if key.startswith("image_")} - depth_names = {key[6:] for key in obs if key.startswith("depth_")} - - if isinstance(resize_size, tuple): - resize_size = {name: resize_size for name in image_names} - if isinstance(depth_resize_size, tuple): - depth_resize_size = {name: depth_resize_size for name in depth_names} - - for name in image_names: - if name not in resize_size: - logging.warning( - f"No resize_size was provided for image_{name}. This will result in 1x1 " - "padding images, which may cause errors if you mix padding and non-padding images." 
- ) - image = obs[f"image_{name}"] - if image.dtype == tf.string: - if tf.strings.length(image) == 0: - # this is a padding image - image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) - else: - image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8) - elif image.dtype != tf.uint8: - raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}") - if name in resize_size: - image = dl.transforms.resize_image(image, size=resize_size[name]) - obs[f"image_{name}"] = image - - for name in depth_names: - if name not in depth_resize_size: - logging.warning( - f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " - "padding depth images, which may cause errors if you mix padding and non-padding images." - ) - depth = obs[f"depth_{name}"] - - if depth.dtype == tf.string: - if tf.strings.length(depth) == 0: - depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32) - else: - depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0] - elif depth.dtype != tf.float32: - raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}") - - if name in depth_resize_size: - depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name]) - - obs[f"depth_{name}"] = depth - - return obs diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/__init__.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/__init__.py deleted file mode 100644 index ae77c4a44e477fea36a3f410dc047bcf0f82ef22..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/oxe/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .materialize import get_oxe_dataset_kwargs_and_weights -from .mixtures import OXE_NAMED_MIXTURES diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/configs.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/configs.py deleted file mode 100644 index 5875b34d61afb7e8b53f3690d5509101b3b99527..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/oxe/configs.py +++ /dev/null @@ -1,709 +0,0 @@ -""" -configs.py - -Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment. - -Configuration adopts the following structure: - image_obs_keys: - primary: primary external RGB - secondary: secondary external RGB - wrist: wrist RGB - - depth_obs_keys: - primary: primary external depth - secondary: secondary external depth - wrist: wrist depth - - # Always 8-dim =>> changes based on `StateEncoding` - state_obs_keys: - StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + (1) + Gripper Open/Close (1) - StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1) - StateEncoding.JOINT: Joint Angles (7, if fewer) + Gripper Open/Close (1) - - state_encoding: Type of `StateEncoding` - action_encoding: Type of action encoding (e.g., EEF Position vs. 
Joint Position) -""" - -from enum import IntEnum - -from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import zero_action_filter - - -# Defines Proprioceptive State Encoding Schemes -class StateEncoding(IntEnum): - # fmt: off - NONE = -1 # No Proprioceptive State - POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + (1) + Gripper Open/Close (1) - POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1) - JOINT = 3 # Joint Angles (7, if fewer) + Gripper Open/Close (1) - JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ]) - # fmt: on - - -# Defines Action Encoding Schemes -class ActionEncoding(IntEnum): - # fmt: off - EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1) - JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1) - JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ]) - EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1) - # fmt: on - - -# === Individual Dataset Configs === -OXE_DATASET_CONFIGS = { - "fractal20220817_data": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["base_pose_tool_reached", "gripper_closed"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "kuka": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": [ - "clip_function_input/base_pose_tool_reached", - "gripper_closed", - ], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture - "image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "bridge_orig": { # Original version of Bridge V2 from project website - "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "bridge_dataset": { # Original version of Bridge V2 from project website - "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "taco_play": { - "image_obs_keys": { - "primary": "rgb_static", - "secondary": None, - "wrist": "rgb_gripper", - }, - "depth_obs_keys": { - "primary": "depth_static", - "secondary": None, - "wrist": "depth_gripper", - }, - "state_obs_keys": ["state_eef", None, "state_gripper"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "jaco_play": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "image_wrist", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state_eef", None, "state_gripper"], - "state_encoding": StateEncoding.POS_EULER, - 
"action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_cable_routing": { - "image_obs_keys": { - "primary": "image", - "secondary": "top_image", - "wrist": "wrist45_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["robot_state", None], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "roboturk": { - "image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": [None, None, None, None, None, None, None, None], - "state_encoding": StateEncoding.NONE, - "action_encoding": ActionEncoding.EEF_POS, - }, - "nyu_door_opening_surprising_effectiveness": { - "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": [None, None, None, None, None, None, None, None], - "state_encoding": StateEncoding.NONE, - "action_encoding": ActionEncoding.EEF_POS, - }, - "viola": { - "image_obs_keys": { - "primary": "agentview_rgb", - "secondary": None, - "wrist": "eye_in_hand_rgb", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["joint_states", "gripper_states"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_autolab_ur5": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "hand_image", - }, - "depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "toto": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "language_table": { - "image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["effector_translation", None, None, None, None, None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "columbia_cairlab_pusht_real": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["robot_state", None, None, None, None, None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["ee_position", "ee_orientation", None], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "nyu_rot_dataset_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "stanford_hydra_dataset_converted_externally_to_rlds": { - 
"image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "austin_buds_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "nyu_franka_play_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": "image_additional_view", - "wrist": None, - }, - "depth_obs_keys": { - "primary": "depth", - "secondary": "depth_additional_view", - "wrist": None, - }, - "state_obs_keys": ["eef_state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "maniskill_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": { - "primary": "depth", - "secondary": None, - "wrist": "wrist_depth", - }, - "state_obs_keys": ["tcp_pose", "gripper_state"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "furniture_bench_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "cmu_franka_exploration_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "highres_image", - "secondary": None, - "wrist": None, - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": [None, None, None, None, None, None, None, None], - "state_encoding": StateEncoding.NONE, - "action_encoding": ActionEncoding.EEF_POS, - }, - "ucsd_kitchen_dataset_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["joint_state", None], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "austin_sailor_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "austin_sirius_dataset_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": 
StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "bc_z": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": [ - "present/xyz", - "present/axis_angle", - None, - "present/sensed_close", - ], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "utokyo_pr2_opening_fridge_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": "image2", - "wrist": "hand_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["end_effector_pose", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "utokyo_xarm_bimanual_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["pose_r", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "robo_net": { - "image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_mvp_converted_externally_to_rlds": { - "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["pose", "gripper"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.JOINT_POS, - }, - "berkeley_rpt_converted_externally_to_rlds": { - "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["joint_pos", "gripper"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.JOINT_POS, - }, - "kaist_nonprehensile_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "stanford_mask_vit_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": 
ActionEncoding.EEF_POS, - }, - "tokyo_u_lsmo_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "dlr_sara_pour_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "dlr_sara_grid_clamp_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "dlr_edan_shared_control_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "asu_table_top_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "stanford_robocook_converted_externally_to_rlds": { - "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, - "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "imperialcollege_sawyer_wrist_cam": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": [None, None, None, None, None, None, None, "state"], - "state_encoding": StateEncoding.NONE, - "action_encoding": ActionEncoding.EEF_POS, - }, - "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["joint_state", "gripper_state"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "uiuc_d3field": { - "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, - "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, - "state_obs_keys": [None, None, None, None, None, None, None, None], - "state_encoding": StateEncoding.NONE, - "action_encoding": ActionEncoding.EEF_POS, - }, - "utaustin_mutex": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_fanuc_manipulation": { - 
"image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "wrist_image", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["joint_state", None, "gripper_state"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "cmu_playing_with_food": { - "image_obs_keys": { - "primary": "image", - "secondary": None, - "wrist": "finger_vision_1", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "cmu_play_fusion": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.EEF_POS, - }, - "cmu_stretch": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_gnm_recon": { - "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_gnm_cory_hall": { - "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "berkeley_gnm_sac_son": { - "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state", None, None], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "droid": { - "image_obs_keys": { - "primary": "exterior_image_1_left", - "secondary": "exterior_image_2_left", - "wrist": "wrist_image_left", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["proprio"], - "state_encoding": StateEncoding.POS_QUAT, - "action_encoding": ActionEncoding.EEF_POS, - "aux_kwargs": { - "dataset_frame_transform_kwargs": { - "chunk_filter_fn": zero_action_filter, - }, - }, - }, - "fmb_dataset": { - "image_obs_keys": { - "primary": "image_side_1", - "secondary": "image_side_2", - "wrist": "image_wrist_1", - }, - "depth_obs_keys": { - "primary": "image_side_1_depth", - "secondary": "image_side_2_depth", - "wrist": "image_wrist_1_depth", - }, - "state_obs_keys": ["proprio"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "dobbe": { - "image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["proprio"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "roboset": { - "image_obs_keys": { - "primary": "image_left", - "secondary": "image_right", - "wrist": "image_wrist", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": 
None}, - "state_obs_keys": ["proprio"], - "state_encoding": StateEncoding.JOINT, - "action_encoding": ActionEncoding.JOINT_POS, - }, - "rh20t": { - "image_obs_keys": { - "primary": "image_front", - "secondary": "image_side_right", - "wrist": "image_wrist", - }, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["proprio"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - ### T-DROID datasets - "tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control - "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "tdroid_pour_corn_in_pot": { # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control - "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control - "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "tdroid_move_object_onto_plate": { # "move onto plate" task, 150 demos @ 5 Hz control - "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "tdroid_knock_object_over": { # "knock over" task, 70 demos @ 5 Hz control - "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "tdroid_cover_object_with_towel": { # "cover with towel" task, 45 demos @ 5 Hz control - "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, - "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - ### DROID Finetuning datasets - "droid_wipe": { - "image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["proprio"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - ### LIBERO datasets (modified versions) - "libero_spatial_no_noops": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", 
"gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "libero_object_no_noops": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "libero_goal_no_noops": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "libero_10_no_noops": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - "libero_4_task_suites_no_noops": { - "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["EEF_state", "gripper_state"], - "state_encoding": StateEncoding.POS_EULER, - "action_encoding": ActionEncoding.EEF_POS, - }, - ### ALOHA fine-tuning datasets - "aloha1_fold_shorts_20_demos": { - "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT_BIMANUAL, - "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, - }, - "aloha1_fold_shirt_30_demos": { - "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT_BIMANUAL, - "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, - }, - "aloha1_scoop_X_into_bowl_45_demos": { - "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT_BIMANUAL, - "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, - }, - "aloha1_put_X_into_pot_300_demos": { - "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, - "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, - "state_obs_keys": ["state"], - "state_encoding": StateEncoding.JOINT_BIMANUAL, - "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, - }, -} diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/materialize.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/materialize.py deleted file mode 100644 index 73f3cfaf4b8c2fac59be2f2426ed04132e4baaee..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/oxe/materialize.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -materialize.py - -Factory class for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for -clear control flow. 
-""" - -from copy import deepcopy -from pathlib import Path -from typing import Any, Dict, List, Tuple - -from prismatic.overwatch import initialize_overwatch -from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX -from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding -from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -def make_oxe_dataset_kwargs( - dataset_name: str, - data_root_dir: Path, - load_camera_views: Tuple[str] = ("primary",), - load_depth: bool = False, - load_proprio: bool = True, - load_language: bool = True, - action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, -) -> Dict[str, Any]: - """Generates config (kwargs) for given dataset from Open-X Embodiment.""" - dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name]) - if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL]: - raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 & JOINT_POS_BIMANUAL actions supported!") - - # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute! - # Normalize all action dimensions *except* the gripper - if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: - dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] - dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] - elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6: - dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True] - dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False] - elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL: - dataset_kwargs["absolute_action_mask"] = [True] * 14 - dataset_kwargs["action_normalization_mask"] = [True] * 14 - dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type - - # Adjust Loaded Camera Views - if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0: - raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`") - - # Filter - dataset_kwargs["image_obs_keys"] = { - k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views - } - dataset_kwargs["depth_obs_keys"] = { - k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views - } - - # Eliminate Unnecessary Keys - dataset_kwargs.pop("state_encoding") - dataset_kwargs.pop("action_encoding") - if not load_depth: - dataset_kwargs.pop("depth_obs_keys") - if not load_proprio: - dataset_kwargs.pop("state_obs_keys") - - # Load Language - if load_language: - dataset_kwargs["language_key"] = "language_instruction" - - # Specify Standardization Transform - dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[dataset_name] - - # Add any aux arguments - if "aux_kwargs" in dataset_kwargs: - dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs")) - - return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs} - - -def get_oxe_dataset_kwargs_and_weights( - data_root_dir: Path, - mixture_spec: List[Tuple[str, float]], - load_camera_views: Tuple[str] = ("primary",), - load_depth: bool = False, - load_proprio: bool = True, - load_language: bool = True, - 
action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, -) -> Tuple[Dict[str, Any], List[float]]: - """ - Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs - (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`. - - :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X) - :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES` - :param load_camera_views: Camera views to load; see `oxe.dataset_configs.py` for available views. - :param load_depth: Load depth information in addition to camera RGB. - :param load_proprio: Load proprioceptive state. - :param load_language: Load language instructions. - :param action_proprio_normalization_type: Normalization scheme to use for proprioceptive actions. - - return: Tuple of (per_dataset_kwargs, sampling_weights) - """ - included_datasets, filtered_mixture_spec = set(), [] - for d_name, d_weight in mixture_spec: - if d_name in included_datasets: - overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`") - continue - - included_datasets.add(d_name) - filtered_mixture_spec.append((d_name, d_weight)) - - # Assemble Dataset Config (kwargs) and Weights - per_dataset_kwargs, sampling_weights = [], [] - for d_name, d_weight in filtered_mixture_spec: - try: - per_dataset_kwargs.append( - make_oxe_dataset_kwargs( - d_name, - data_root_dir, - load_camera_views, - load_depth, - load_proprio, - load_language, - action_proprio_normalization_type, - ) - ) - sampling_weights.append(d_weight) - - except ValueError as e: - overwatch.warning(f"Skipping `{d_name}` due to Error: {e}") - - return per_dataset_kwargs, sampling_weights diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/mixtures.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/mixtures.py deleted file mode 100644 index 0f2c3732c4068ddf08b8f94362304799cebaab83..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/oxe/mixtures.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -mixtures.py - -Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets. Each dataset is associated with -a float "sampling weight" -""" - -from typing import Dict, List, Tuple - -# fmt: off -OXE_NAMED_MIXTURES: Dict[str, List[Tuple[str, float]]] = { - # === Bridge V2 Dataset === - "bridge": [ - # ("bridge_oxe", 1.0), # Version of Bridge V2 in Open-X GCP Bucket - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - ], - - - # === [Moderate-Scale] Bridge++ Mixtures === - "bridge_rt_1": [ - # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - - ("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale) - ], - - # === RT-X Mixtures === - "rtx": [ - ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) - ("kuka", 0.8341046294), - # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - ("taco_play", 2.0), - ("jaco_play", 2.0), - ("berkeley_cable_routing", 3.0), - ("roboturk", 1.0), - # ("nyu_door_opening_surprising_effectiveness", 5.0), # Note --> only contains wrist camera images (skip?) 
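(Aside, assumed usage rather than code from the deleted files: the sketch below pairs one of the named mixtures above with the `materialize.py` helpers defined earlier. The data root path is a placeholder; these helpers only assemble kwargs dictionaries and do not read from disk.)

```python
# Minimal usage sketch (assumed): resolve a named mixture into per-dataset
# kwargs and sampling weights for the interleaved dataset builder.
from pathlib import Path

from prismatic.vla.datasets.rlds.oxe.materialize import get_oxe_dataset_kwargs_and_weights
from prismatic.vla.datasets.rlds.oxe.mixtures import OXE_NAMED_MIXTURES

per_dataset_kwargs, sampling_weights = get_oxe_dataset_kwargs_and_weights(
    data_root_dir=Path("/data/oxe"),             # placeholder RLDS/TFDS root
    mixture_spec=OXE_NAMED_MIXTURES["bridge"],   # [("bridge_orig", 1.0)]
    load_camera_views=("primary",),
)

# One kwargs dict per dataset plus a parallel list of sampling weights.
print(per_dataset_kwargs[0]["name"], sampling_weights[0])   # bridge_orig 1.0

# Per the EEF_POS contract in make_oxe_dataset_kwargs, only the gripper
# dimension is absolute and it is excluded from normalization.
assert per_dataset_kwargs[0]["absolute_action_mask"] == [False] * 6 + [True]
assert per_dataset_kwargs[0]["action_normalization_mask"] == [True] * 6 + [False]
```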
- ("viola", 2.0), - ("berkeley_autolab_ur5", 1.0), - ("toto", 1.0), - ], - - "rtx_franka": [ - ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) - ("kuka", 0.8341046294), - # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - ("taco_play", 2.0), - ("jaco_play", 2.0), - ("berkeley_cable_routing", 3.0), - ("roboturk", 1.0), - # ("nyu_door_opening_surprising_effectiveness", 5.0), # Note --> only contains wrist camera images (skip?) - ("viola", 2.0), - ("berkeley_autolab_ur5", 1.0), - ("toto", 1.0), - - ("taco_play", 1.0), - ("berkeley_cable_routing", 1.0), - ("viola", 1.0), - ("toto", 1.0), - ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), - ("austin_buds_dataset_converted_externally_to_rlds", 3.0), - ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), - ("maniskill_dataset_converted_externally_to_rlds", 0.1), - ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), - ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), - ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), - ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), - ("berkeley_rpt_converted_externally_to_rlds", 1.0), - ("kaist_nonprehensile_converted_externally_to_rlds", 3.0), - ("stanford_robocook_converted_externally_to_rlds", 1.0), - ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), - ("utaustin_mutex", 1.0), - ("cmu_play_fusion", 1.0), - ], - - # === Open-X Magic Soup === - "oxe_magic_soup": [ - ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) - ("kuka", 0.8341046294), - # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - ("taco_play", 2.0), - ("jaco_play", 1.0), - ("berkeley_cable_routing", 1.0), - ("roboturk", 2.0), - # ("nyu_door_opening_surprising_effectiveness", 1.0), # Note --> only contains wrist camera images (skip?) - ("viola", 2.0), - ("berkeley_autolab_ur5", 2.0), - ("toto", 1.0), - ("language_table", 0.1), - ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), - ("austin_buds_dataset_converted_externally_to_rlds", 1.0), - ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), - ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), - ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), - ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), - ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), - # ("bc_z", 0.2), # Note --> raw data is broken! - ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), - ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), - # ("uiuc_d3field", 1.0), # Note --> raw data is broken! 
- ("utaustin_mutex", 1.0), - ("berkeley_fanuc_manipulation", 2.0), - ("cmu_stretch", 1.0), - ], - - # === Open-X Magic Soup++ === - "oxe_magic_soup_plus": [ - ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) - ("kuka", 0.8341046294), - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - ("taco_play", 2.0), - ("jaco_play", 1.0), - ("berkeley_cable_routing", 1.0), - ("roboturk", 2.0), - ("viola", 2.0), - ("berkeley_autolab_ur5", 2.0), - ("toto", 1.0), - ("language_table", 0.1), - ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), - ("austin_buds_dataset_converted_externally_to_rlds", 1.0), - ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), - ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), - ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), - ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), - ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), - ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), - ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), - ("utaustin_mutex", 1.0), - ("berkeley_fanuc_manipulation", 2.0), - ("cmu_stretch", 1.0), - ## New Datasets in MagicSoup++ - ("bc_z", 0.2), # Note: use v0.1.0 --> later versions broken - ("fmb_dataset", 1.0), - ("dobbe", 0.2), - ("droid", 0.06), - ], - - "oxe_magic_soup_plus_minus": [ - ("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale) - ("kuka", 0.8341046294), - ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website - ("taco_play", 2.0), - ("jaco_play", 1.0), - ("berkeley_cable_routing", 1.0), - ("roboturk", 2.0), - ("viola", 2.0), - ("berkeley_autolab_ur5", 2.0), - ("toto", 1.0), - # ("language_table", 0.1), - ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), - ("austin_buds_dataset_converted_externally_to_rlds", 1.0), - ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), - ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), - ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), - ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), - ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), - ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), - ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), - ("utaustin_mutex", 1.0), - ("berkeley_fanuc_manipulation", 2.0), - ("cmu_stretch", 1.0), - ## New Datasets in MagicSoup++ - ("bc_z", 0.2), # Note: use v0.1.0 --> later versions broken - ("fmb_dataset", 1.0), - ("dobbe", 0.2), - # ("droid", 0.06), - ], - - # === T-DROID Dataset === - "tdroid_carrot_in_bowl": [ - ("tdroid_carrot_in_bowl", 1.0), - ], - "tdroid_pour_corn_in_pot": [ - ("tdroid_pour_corn_in_pot", 1.0), - ], - "tdroid_flip_pot_upright": [ - ("tdroid_flip_pot_upright", 1.0), - ], - "tdroid_move_object_onto_plate": [ - ("tdroid_move_object_onto_plate", 1.0), - ], - "tdroid_knock_object_over": [ - ("tdroid_knock_object_over", 1.0), - ], - "tdroid_cover_object_with_towel": [ - ("tdroid_cover_object_with_towel", 1.0), - ], - - # === DROID Finetuning Datasets === - "droid_wipe": [ - ("droid_wipe", 1.0), - ], - - # === LIBERO Datasets (Modified Versions) === - "libero_spatial_no_noops": [ - ("libero_spatial_no_noops", 1.0), - ], - "libero_object_no_noops": [ - ("libero_object_no_noops", 1.0), - ], - "libero_goal_no_noops": [ - ("libero_goal_no_noops", 1.0), - ], - "libero_10_no_noops": [ - ("libero_10_no_noops", 1.0), - ], - "libero_4_task_suites_no_noops": [ - 
("libero_spatial_no_noops", 1.0), - ("libero_object_no_noops", 1.0), - ("libero_goal_no_noops", 1.0), - ("libero_10_no_noops", 1.0), - ], - - # === ALOHA Fine-Tuning Datasets === - "aloha1_fold_shorts_20_demos": [ - ("aloha1_fold_shorts_20_demos", 1.0), - ], - "aloha1_fold_shirt_30_demos": [ - ("aloha1_fold_shirt_30_demos", 1.0), - ], - "aloha1_scoop_X_into_bowl_45_demos": [ - ("aloha1_scoop_X_into_bowl_45_demos", 1.0), - ], - "aloha1_put_X_into_pot_300_demos": [ - ("aloha1_put_X_into_pot_300_demos", 1.0), - ], -# fmt: on -} diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/transforms.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/transforms.py deleted file mode 100644 index f4853abdda82517e34a4b9d806253a56d07d7124..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/oxe/transforms.py +++ /dev/null @@ -1,933 +0,0 @@ -""" -transforms.py - -Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment. - -Transforms adopt the following structure: - Input: Dictionary of *batched* features (i.e., has leading time dimension) - Output: Dictionary `step` =>> { - "observation": { - - State (in chosen state representation) - }, - "action": Action (in chosen action representation), - "language_instruction": str - } -""" - -from typing import Any, Dict - -import tensorflow as tf - -from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import droid_baseact_transform, droid_finetuning_transform -from prismatic.vla.datasets.rlds.utils.data_utils import ( - binarize_gripper_actions, - invert_gripper_actions, - rel2abs_gripper_actions, - relabel_bridge_actions, -) - - -def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - Applies to version of Bridge V2 in Open X-Embodiment mixture. - - Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! - """ - for key in trajectory.keys(): - if key == "traj_metadata": - continue - elif key in ["observation", "action"]: - for key2 in trajectory[key]: - trajectory[key][key2] = trajectory[key][key2][1:] - else: - trajectory[key] = trajectory[key][1:] - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - trajectory = relabel_bridge_actions(trajectory) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - Applies to original version of Bridge V2 from the official project website. - - Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! 
- """ - for key in trajectory.keys(): - if key == "traj_metadata": - continue - elif key == "observation": - for key2 in trajectory[key]: - trajectory[key][key2] = trajectory[key][key2][1:] - else: - trajectory[key] = trajectory[key][1:] - - trajectory["action"] = tf.concat( - [ - trajectory["action"][:, :6], - binarize_gripper_actions(trajectory["action"][:, -1])[:, None], - ], - axis=1, - ) - trajectory = relabel_bridge_actions(trajectory) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - [ - trajectory["action"][:, :6], - binarize_gripper_actions(trajectory["action"][:, -1])[:, None], - ], - axis=1, - ) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] - return trajectory - - -def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - # decode compressed state - eef_value = tf.io.decode_compressed( - trajectory["observation"]["clip_function_input/base_pose_tool_reached"], - compression_type="ZLIB", - ) - eef_value = tf.io.decode_raw(eef_value, tf.float32) - trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) - gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") - gripper_value = tf.io.decode_raw(gripper_value, tf.float32) - trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6] - trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8] - trajectory["action"] = trajectory["action"]["rel_actions_world"] - - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - tf.clip_by_value(trajectory["action"][:, -1:], 0, 1), - ), - axis=-1, - ) - - 
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] - trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:] - - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - tf.zeros_like(trajectory["action"]["world_vector"]), - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - tf.zeros_like(trajectory["action"]["world_vector"][:, :1]), - ), - axis=-1, - ) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert absolute gripper action, +1 = open, 0 = close - gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action, - ), - axis=-1, - ) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, None] - gripper_action = tf.clip_by_value(gripper_action, 0, 1) - gripper_action = invert_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action, - ), - axis=-1, - ) - # trajectory["language_instruction"] = tf.fill( - # 
tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14] - trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth") - - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), - ), - axis=-1, - ) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # default to "open" gripper - trajectory["action"] = tf.concat( - ( - trajectory["action"], - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"]), - tf.ones_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - - # decode language instruction - instruction_bytes = trajectory["observation"]["instruction"] - instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") - # Remove trailing padding --> convert RaggedTensor to regular Tensor. 
- trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0] - return trajectory - - -def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - trajectory["action"]["gripper_closedness_action"][:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tf.zeros_like(trajectory["action"][:, :3]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:] - trajectory["action"] = trajectory["action"][..., :7] - return trajectory - - -def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(trajectory["action"][:, -1:]), - ), - axis=-1, - ) - - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :3], - trajectory["observation"]["state"][:, 7:10], - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2] - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["language_instruction"]), "" - # ) # delete uninformative language instruction - return trajectory - - -def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["language_instruction"]), "" - # ) # delete uninformative language instruction - return trajectory - - -def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32) - trajectory["observation"]["depth_additional_view"] = tf.cast( - trajectory["observation"]["depth_additional_view"][..., 0], tf.float32 - ) - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:] - - # clip gripper action, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, -8:-2], - tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1), - ), - axis=-1, - ) - - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["language_instruction"]), "" - # ) # delete uninformative language instruction - return trajectory - - -def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8] - return trajectory - - -def 
furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - import tensorflow_graphics.geometry.transformation as tft - - trajectory["observation"]["state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :7], - trajectory["observation"]["state"][:, -1:], - ), - axis=-1, - ) - - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tft.euler.from_quaternion(trajectory["action"][:, 3:7]), - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - return trajectory - - -def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tf.zeros_like(trajectory["action"][:, :3]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["language_instruction"]), "" - # ) # delete uninformative language instruction - return trajectory - - -def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["language_instruction"]), "" - # ) # delete uninformative language instruction - return trajectory - - -def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["future/xyz_residual"][:, :3], - trajectory["action"]["future/axis_angle_residual"][:, :3], - invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)), - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = 
trajectory["observation"]["state"][:, -1:] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - return trajectory - - -def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = trajectory["action"][..., -7:] - return trajectory - - -def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :4], - tf.zeros_like(trajectory["observation"]["state"][:, :2]), - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :4], - tf.zeros_like(trajectory["action"][:, :2]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - return trajectory - - -def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - return trajectory - - -def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - tf.zeros_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - return trajectory - - -def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["end_effector_pose"][:, :4], - tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]), - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :4], - tf.zeros_like(trajectory["action"][:, :2]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - return trajectory - - -def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6] - return trajectory - - -def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(trajectory["action"][:, -1:]), - ), - axis=-1, - ) - return trajectory - - -def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def 
imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - import tensorflow_graphics.geometry.transformation as tft - - trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tft.euler.from_quaternion(trajectory["action"][:, 3:7]), - trajectory["action"][:, 7:8], - ), - axis=-1, - ) - return trajectory - - -def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"], - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - return trajectory - - -def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] - - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["language_instruction"]), "" - # ) # delete uninformative language instruction - return trajectory - - -def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7] - - # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"], - invert_gripper_actions(trajectory["observation"]["gripper_state"]), - ), - axis=-1, - ) - return trajectory - - -def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - import tensorflow_graphics.geometry.transformation as tft - - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tft.euler.from_quaternion(trajectory["action"][:, 3:7]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - trajectory["action"][:, -4:], - ), - axis=-1, - ) - return trajectory - - -def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :3], - tf.zeros_like(trajectory["observation"]["state"][:, :3]), - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = tf.concat( - ( - trajectory["observation"]["position"], - tf.zeros_like(trajectory["observation"]["state"][:, :3]), - trajectory["observation"]["yaw"], - ), - axis=-1, - ) - trajectory["action"] = tf.concat( - ( - trajectory["action"], - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"]), - 
tf.zeros_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - return trajectory - - -def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # every input feature is batched, ie has leading batch dimension - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["eef_pose"], - trajectory["observation"]["state_gripper_pose"][..., None], - ), - axis=-1, - ) - return trajectory - - -def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # every input feature is batched, ie has leading batch dimension - trajectory["observation"]["proprio"] = trajectory["observation"]["state"] - return trajectory - - -def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # every input feature is batched, ie has leading batch dimension - trajectory["observation"]["proprio"] = trajectory["observation"]["state"] - - # gripper action is in -1...1 --> clip to 0...1, flip - gripper_action = trajectory["action"][:, -1:] - gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) - - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :7], - gripper_action, - ), - axis=-1, - ) - return trajectory - - -def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["tcp_base"], - tf.cast(trajectory["action"]["gripper"][:, None], tf.float32), - ), - axis=-1, - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["tcp_base"], - trajectory["observation"]["gripper_width"][..., None], - ), - axis=-1, - ) - return trajectory - - -def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - [ - trajectory["action"][:, :6], - binarize_gripper_actions(trajectory["action"][:, -1])[:, None], - ], - axis=1, - ) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] - return trajectory - - -def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close - gripper_action = trajectory["action"][:, -1:] - gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) - - trajectory["action"] = tf.concat( - [ - trajectory["action"][:, :6], - gripper_action, - ], - axis=1, - ) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state - return trajectory - - -def aloha_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # Don't need to do anything because dataset is already in the correct format - return trajectory - - -# === Registry === -OXE_STANDARDIZATION_TRANSFORMS = { - "bridge_oxe": bridge_oxe_dataset_transform, - "bridge_orig": bridge_orig_dataset_transform, - "bridge_dataset": bridge_orig_dataset_transform, - "ppgm": ppgm_dataset_transform, - "ppgm_static": ppgm_dataset_transform, - "ppgm_wrist": ppgm_dataset_transform, - "fractal20220817_data": rt1_dataset_transform, - "kuka": kuka_dataset_transform, - "taco_play": taco_play_dataset_transform, - "jaco_play": jaco_play_dataset_transform, - "berkeley_cable_routing": berkeley_cable_routing_dataset_transform, - "roboturk": roboturk_dataset_transform, - 
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform, - "viola": viola_dataset_transform, - "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform, - "toto": toto_dataset_transform, - "language_table": language_table_dataset_transform, - "columbia_cairlab_pusht_real": pusht_dataset_transform, - "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform, - "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform, - "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform, - "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform, - "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform, - "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform, - "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform, - "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform, - "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform, - "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform, - "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform, - "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform, - "bc_z": bc_z_dataset_transform, - "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform, - "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform, - "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform, - "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform, - "robo_net": robo_net_dataset_transform, - "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform, - "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform, - "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform, - "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform, - "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform, - "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform, - "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform, - "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform, - "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform, - "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform, - "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform, - "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform, - "uiuc_d3field": uiuc_d3field_dataset_transform, - "utaustin_mutex": utaustin_mutex_dataset_transform, - "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform, - "cmu_playing_with_food": cmu_playing_with_food_dataset_transform, - "cmu_play_fusion": playfusion_dataset_transform, - "cmu_stretch": cmu_stretch_dataset_transform, - "berkeley_gnm_recon": gnm_dataset_transform, - "berkeley_gnm_cory_hall": gnm_dataset_transform, - "berkeley_gnm_sac_son": gnm_dataset_transform, - "droid": droid_baseact_transform, - "fmb_dataset": fmb_dataset_transform, - "dobbe": dobbe_dataset_transform, - 
"roboset": roboset_dataset_transform, - "rh20t": rh20t_dataset_transform, - ### T-DROID datasets - "tdroid_carrot_in_bowl": tdroid_dataset_transform, - "tdroid_pour_corn_in_pot": tdroid_dataset_transform, - "tdroid_flip_pot_upright": tdroid_dataset_transform, - "tdroid_move_object_onto_plate": tdroid_dataset_transform, - "tdroid_knock_object_over": tdroid_dataset_transform, - "tdroid_cover_object_with_towel": tdroid_dataset_transform, - ### DROID Finetuning datasets - "droid_wipe": droid_finetuning_transform, - ### LIBERO datasets (modified versions) - "libero_spatial_no_noops": libero_dataset_transform, - "libero_object_no_noops": libero_dataset_transform, - "libero_goal_no_noops": libero_dataset_transform, - "libero_10_no_noops": libero_dataset_transform, - "libero_4_task_suites_no_noops": libero_dataset_transform, - ### ALOHA fine-tuning datasets - "aloha1_fold_shorts_20_demos": aloha_dataset_transform, - "aloha1_fold_shirt_30_demos": aloha_dataset_transform, - "aloha1_scoop_X_into_bowl_45_demos": aloha_dataset_transform, - "aloha1_put_X_into_pot_300_demos": aloha_dataset_transform, -} diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py deleted file mode 100644 index b98e59bc2fb0f9b498e00eaca189c2379304e5aa..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Episode transforms for DROID dataset.""" - -from typing import Any, Dict - -import tensorflow as tf -import tensorflow_graphics.geometry.transformation as tfg - - -def rmat_to_euler(rot_mat): - return tfg.euler.from_rotation_matrix(rot_mat) - - -def euler_to_rmat(euler): - return tfg.rotation_matrix_3d.from_euler(euler) - - -def invert_rmat(rot_mat): - return tfg.rotation_matrix_3d.inverse(rot_mat) - - -def rotmat_to_rot6d(mat): - """ - Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). - Args: - mat: rotation matrix - - Returns: 6d vector (first two rows of rotation matrix) - - """ - r6 = mat[..., :2, :] - r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] - r6_flat = tf.concat([r6_0, r6_1], axis=-1) - return r6_flat - - -def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): - """ - Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. - Args: - velocity: 6d velocity action (3 x translation, 3 x rotation) - wrist_in_robot_frame: 6d pose of the end-effector in robot base frame - - Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) - - """ - R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) - R_frame_inv = invert_rmat(R_frame) - - # world to wrist: dT_pi = R^-1 dT_rbt - vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] - - # world to wrist: dR_pi = R^-1 dR_rbt R - dR = euler_to_rmat(velocity[:, 3:6]) - dR = R_frame_inv @ (dR @ R_frame) - dR_r6 = rotmat_to_rot6d(dR) - return tf.concat([vel_t, dR_r6], axis=-1) - - -def rand_swap_exterior_images(img1, img2): - """ - Randomly swaps the two exterior images (for training with single exterior input). - """ - return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) - - -def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - DROID dataset transformation for actions expressed in *base* frame of the robot. 
- """ - dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] - dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] - - trajectory["action"] = tf.concat( - ( - dt, - dR, - 1 - trajectory["action_dict"]["gripper_position"], - ), - axis=-1, - ) - trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( - rand_swap_exterior_images( - trajectory["observation"]["exterior_image_1_left"], - trajectory["observation"]["exterior_image_2_left"], - ) - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["cartesian_position"], - trajectory["observation"]["gripper_position"], - ), - axis=-1, - ) - return trajectory - - -def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - DROID dataset transformation for actions expressed in *wrist* frame of the robot. - """ - wrist_act = velocity_act_to_wrist_frame( - trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] - ) - trajectory["action"] = tf.concat( - ( - wrist_act, - trajectory["action_dict"]["gripper_position"], - ), - axis=-1, - ) - trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( - rand_swap_exterior_images( - trajectory["observation"]["exterior_image_1_left"], - trajectory["observation"]["exterior_image_2_left"], - ) - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["cartesian_position"], - trajectory["observation"]["gripper_position"], - ), - axis=-1, - ) - return trajectory - - -def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - DROID dataset transformation for actions expressed in *base* frame of the robot. - """ - dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] - dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] - trajectory["action"] = tf.concat( - ( - dt, - dR, - 1 - trajectory["action_dict"]["gripper_position"], - ), - axis=-1, - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["cartesian_position"], - trajectory["observation"]["gripper_position"], - ), - axis=-1, - ) - return trajectory - - -def zero_action_filter(traj: Dict) -> bool: - """ - Filters transitions whose actions are all-0 (only relative actions, no gripper action). - Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". - """ - DROID_Q01 = tf.convert_to_tensor( - [ - -0.7776297926902771, - -0.5803514122962952, - -0.5795090794563293, - -0.6464047729969025, - -0.7041108310222626, - -0.8895104378461838, - ] - ) - DROID_Q99 = tf.convert_to_tensor( - [ - 0.7597932070493698, - 0.5726242214441299, - 0.7351000607013702, - 0.6705610305070877, - 0.6464948207139969, - 0.8897542208433151, - ] - ) - DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 - - return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) diff --git a/capvector-oft/prismatic/vla/datasets/rlds/traj_transforms.py b/capvector-oft/prismatic/vla/datasets/rlds/traj_transforms.py deleted file mode 100644 index 9d943abbb532fa1af171aa2dc467d1a3c5114c56..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/traj_transforms.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -traj_transforms.py - -Contains trajectory transforms used in the orca data pipeline. 
Trajectory transforms operate on a dictionary -that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length). -""" - -import logging -from typing import Dict - -import tensorflow as tf - - -def chunk_act_obs(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict: - """ - Chunks actions and observations into the given window_size. - - "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` - observations from the past and the current observation. "action" is given a new axis (at index 1) of size - `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current - action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and - indicates whether an observation should be considered padding (i.e. if it had come from a timestep - before the start of the trajectory). - """ - traj_len = tf.shape(traj["action"])[0] - action_dim = traj["action"].shape[-1] - effective_traj_len = traj_len - future_action_window_size - chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [effective_traj_len, window_size]) + tf.broadcast_to( - tf.range(effective_traj_len)[:, None], [effective_traj_len, window_size] - ) - - action_chunk_indices = tf.broadcast_to( - tf.range(-window_size + 1, 1 + future_action_window_size), - [effective_traj_len, window_size + future_action_window_size], - ) + tf.broadcast_to( - tf.range(effective_traj_len)[:, None], - [effective_traj_len, window_size + future_action_window_size], - ) - - floored_chunk_indices = tf.maximum(chunk_indices, 0) - - goal_timestep = tf.fill([effective_traj_len], traj_len - 1) - - floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) - - traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) - traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) - - # indicates whether an entire observation is padding - traj["observation"]["pad_mask"] = chunk_indices >= 0 - - # Truncate other elements of the trajectory dict - traj["task"] = tf.nest.map_structure(lambda x: tf.gather(x, tf.range(effective_traj_len)), traj["task"]) - traj["dataset_name"] = tf.gather(traj["dataset_name"], tf.range(effective_traj_len)) - traj["absolute_action_mask"] = tf.gather(traj["absolute_action_mask"], tf.range(effective_traj_len)) - - return traj - - -def subsample(traj: Dict, subsample_length: int) -> Dict: - """Subsamples trajectories to the given length.""" - traj_len = tf.shape(traj["action"])[0] - if traj_len > subsample_length: - indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] - traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) - - return traj - - -def add_pad_mask_dict(traj: Dict) -> Dict: - """ - Adds a dictionary indicating which elements of the observation/task should be treated as padding. 
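As an aside on `chunk_act_obs` above: the windowing is easier to see on a toy trajectory. Below is an illustrative NumPy sketch (not part of the pipeline; the trajectory length, window size, and future-action window are made-up values) of the same index construction, clamping, and padding mask.

```python
import numpy as np

traj_len, window_size, future = 5, 2, 1
effective_len = traj_len - future  # timesteps that still have `future` actions ahead

# Row t holds the observation timesteps [t - window_size + 1, ..., t].
chunk_indices = np.arange(-window_size + 1, 1)[None, :] + np.arange(effective_len)[:, None]

# Action chunks additionally reach `future` steps past t.
action_chunk_indices = (
    np.arange(-window_size + 1, 1 + future)[None, :] + np.arange(effective_len)[:, None]
)

# Out-of-range indices are clamped: before the start -> 0 (and flagged as padding),
# past the end -> last timestep.
floored_obs_indices = np.maximum(chunk_indices, 0)
floored_action_indices = np.clip(action_chunk_indices, 0, traj_len - 1)
pad_mask = chunk_indices >= 0

print(floored_obs_indices)     # rows: [0 0], [0 1], [1 2], [2 3]
print(floored_action_indices)  # rows: [0 0 1], [0 1 2], [1 2 3], [2 3 4]
print(pad_mask[0])             # [False  True] -> the first window repeats timestep 0 as padding
```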
- =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} - """ - traj_len = tf.shape(traj["action"])[0] - - for key in ["observation", "task"]: - pad_mask_dict = {} - for subkey in traj[key]: - # Handles "language_instruction", "image_*", and "depth_*" - if traj[key][subkey].dtype == tf.string: - pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 - - # All other keys should not be treated as padding - else: - pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) - - traj[key]["pad_mask_dict"] = pad_mask_dict - - return traj diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/__init__.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/data_utils.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/data_utils.py deleted file mode 100644 index df49bd6f8defc3ed431dd3cfd5054646f771c0f1..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/utils/data_utils.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -data_utils.py - -Additional RLDS-specific data utilities. -""" - -import hashlib -import json -import os -from typing import Any, Callable, Dict, List, Optional, Tuple - -import dlimp as dl -import numpy as np -import tensorflow as tf -from tqdm import tqdm - -from prismatic.overwatch import initialize_overwatch -from prismatic.vla.constants import NormalizationType - -# Initialize Overwatch =>> Wraps `logging.Logger` -overwatch = initialize_overwatch(__name__) - - -def tree_map(fn: Callable, tree: Dict) -> Dict: - return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} - - -def tree_merge(*trees: Dict) -> Dict: - merged = {} - for tree in trees: - for k, v in tree.items(): - if isinstance(v, dict): - merged[k] = tree_merge(merged.get(k, {}), v) - else: - merged[k] = v - return merged - - -def to_padding(tensor: tf.Tensor) -> tf.Tensor: - if tf.debugging.is_numeric_tensor(tensor): - return tf.zeros_like(tensor) - elif tensor.dtype == tf.string: - return tf.fill(tf.shape(tensor), "") - else: - raise ValueError(f"Cannot generate padding for tensor of type {tensor.dtype}.") - - -# === State / Action Processing Primitives === - - -# ruff: noqa: B023 -def normalize_action_and_proprio(traj: Dict, metadata: Dict, normalization_type: NormalizationType): - """Normalizes the action and proprio fields of a trajectory using the given metadata.""" - keys_to_normalize = {"action": "action", "proprio": "observation/proprio"} - - if normalization_type == NormalizationType.NORMAL: - for key, traj_key in keys_to_normalize.items(): - mask = metadata[key].get("mask", tf.ones_like(metadata[key]["mean"], dtype=tf.bool)) - traj = dl.transforms.selective_tree_map( - traj, - match=lambda k, _: k == traj_key, - map_fn=lambda x: tf.where(mask, (x - metadata[key]["mean"]) / (metadata[key]["std"] + 1e-8), x), - ) - - return traj - - elif normalization_type in [NormalizationType.BOUNDS, NormalizationType.BOUNDS_Q99]: - for key, traj_key in keys_to_normalize.items(): - if normalization_type == NormalizationType.BOUNDS: - low = metadata[key]["min"] - high = metadata[key]["max"] - elif normalization_type == NormalizationType.BOUNDS_Q99: - low = metadata[key]["q01"] - high = metadata[key]["q99"] - mask = metadata[key].get("mask", tf.ones_like(metadata[key]["min"], dtype=tf.bool)) - traj = 
dl.transforms.selective_tree_map( - traj, - match=lambda k, _: k == traj_key, - map_fn=lambda x: tf.where( - mask, - tf.clip_by_value(2 * (x - low) / (high - low + 1e-8) - 1, -1, 1), - x, - ), - ) - - # Note (Moo Jin): Map unused action dimensions (i.e., dimensions where min == max) to all 0s. - zeros_mask = metadata[key]["min"] == metadata[key]["max"] - traj = dl.transforms.selective_tree_map( - traj, match=lambda k, _: k == traj_key, map_fn=lambda x: tf.where(zeros_mask, 0.0, x) - ) - - return traj - - raise ValueError(f"Unknown Normalization Type {normalization_type}") - - -def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor: - """ - Converts gripper actions from continuous to binary values (0 and 1). - - We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it - transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate - values based on the state that is reached _after_ those intermediate values. - - In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that - chunk of intermediate values as the last action in the trajectory. - - The `scan_fn` implements the following logic: - new_actions = np.empty_like(actions) - carry = actions[-1] - for i in reversed(range(actions.shape[0])): - if in_between_mask[i]: - carry = carry - else: - carry = float(open_mask[i]) - new_actions[i] = carry - """ - open_mask, closed_mask = actions > 0.95, actions < 0.05 - in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask)) - is_open_float = tf.cast(open_mask, tf.float32) - - def scan_fn(carry, i): - return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i]) - - return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True) - - -def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor: - return 1 - actions - - -def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor: - """ - Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open). - - Assumes that the first relative gripper is not redundant (i.e. close when already closed)! 
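To illustrate the relabeling performed by `binarize_gripper_actions` above, here is a plain-Python toy version of the same backward pass (the gripper values are made up; this sketch is for illustration, not part of the data pipeline):

```python
import numpy as np

actions = np.array([1.0, 0.97, 0.6, 0.3, 0.02, 0.01, 0.5, 0.98])
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between = ~(open_mask | closed_mask)

new_actions = np.empty_like(actions)
carry = actions[-1]                      # seed with the final raw value
for i in reversed(range(len(actions))):
    if not in_between[i]:
        carry = float(open_mask[i])      # 1.0 if clearly open, 0.0 if clearly closed
    new_actions[i] = carry               # intermediate values inherit the *next* clear state

print(new_actions)  # [1. 1. 0. 0. 0. 0. 1. 1.]
```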
- """ - # Note =>> -1 for closing, 1 for opening, 0 for no change - opening_mask, closing_mask = actions < -0.1, actions > 0.1 - thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0)) - - def scan_fn(carry, i): - return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i]) - - # If no relative grasp, assumes open for whole trajectory - start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)] - start = tf.cond(start == 0, lambda: 1, lambda: start) - - # Note =>> -1 for closed, 1 for open - new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start) - new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5 - - return new_actions - - -# === Bridge-V2 =>> Dataset-Specific Transform === -def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]: - """Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" - movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6] - traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) - traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1) - - return traj_truncated - - -# === RLDS Dataset Initialization Utilities === -def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None: - print("\n######################################################################################") - print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #") - for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights): - pad = 80 - len(dataset_kwargs["name"]) - print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #") - print("######################################################################################\n") - - -def get_dataset_statistics( - dataset: dl.DLataset, - hash_dependencies: Tuple[str, ...], - save_dir: Optional[str] = None, -) -> Dict: - """ - Either computes the statistics of a dataset or loads them from a cache file if this function has been called before - with the same `hash_dependencies`. - - Currently, the statistics include the min/max/mean/std of the actions and proprio as well as the number of - transitions and trajectories in the dataset. 
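The effect of `rel2abs_gripper_actions` above is likewise easiest to see on a toy sequence. The sketch below (illustrative only, plain Python rather than TensorFlow) mirrors its thresholding, forward carry, and rescaling to 0/1:

```python
rel = [0, 0, 1, 0, 0, -1, 0]          # relative commands: close at t=2, open again at t=5

# +1 here means "an opening event", -1 "a closing event", 0 "no change".
events = [1 if a < -0.1 else (-1 if a > 0.1 else 0) for a in rel]

# State before the first event is the opposite of that event (closing first => was open).
first = next((e for e in events if e != 0), 0)
state = -first if first != 0 else 1

abs_actions = []
for e in events:
    if e != 0:
        state = e
    abs_actions.append(state / 2 + 0.5)  # map -1/+1 to 0 (closed) / 1 (open)

print(abs_actions)  # [1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]
```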
- """ - unique_hash = hashlib.sha256("".join(hash_dependencies).encode("utf-8"), usedforsecurity=False).hexdigest() - - # Fallback local path for when data_dir is not writable or not provided - local_path = os.path.expanduser(os.path.join("~", ".cache", "orca", f"dataset_statistics_{unique_hash}.json")) - if save_dir is not None: - path = tf.io.gfile.join(save_dir, f"dataset_statistics_{unique_hash}.json") - else: - path = local_path - - # check if cache file exists and load - if tf.io.gfile.exists(path): - overwatch.info(f"Loading existing dataset statistics from {path}.") - with tf.io.gfile.GFile(path, "r") as f: - metadata = json.load(f) - return metadata - - if os.path.exists(local_path): - overwatch.info(f"Loading existing dataset statistics from {local_path}.") - with open(local_path, "r") as f: - metadata = json.load(f) - return metadata - - dataset = dataset.traj_map( - lambda traj: { - "action": traj["action"], - "proprio": ( - traj["observation"]["proprio"] if "proprio" in traj["observation"] else tf.zeros_like(traj["action"]) - ), - } - ) - - cardinality = dataset.cardinality().numpy() - if cardinality == tf.data.INFINITE_CARDINALITY: - raise ValueError("Cannot compute dataset statistics for infinite datasets.") - - overwatch.info("Computing dataset statistics. This may take a bit, but should only need to happen once.") - actions, proprios, num_transitions, num_trajectories = [], [], 0, 0 - for traj in tqdm(dataset.iterator(), total=cardinality if cardinality != tf.data.UNKNOWN_CARDINALITY else None): - actions.append(traj["action"]) - proprios.append(traj["proprio"]) - num_transitions += traj["action"].shape[0] - num_trajectories += 1 - - actions, proprios = np.concatenate(actions), np.concatenate(proprios) - metadata = { - "action": { - "mean": actions.mean(0).tolist(), - "std": actions.std(0).tolist(), - "max": actions.max(0).tolist(), - "min": actions.min(0).tolist(), - "q01": np.quantile(actions, 0.01, axis=0).tolist(), - "q99": np.quantile(actions, 0.99, axis=0).tolist(), - }, - "proprio": { - "mean": proprios.mean(0).tolist(), - "std": proprios.std(0).tolist(), - "max": proprios.max(0).tolist(), - "min": proprios.min(0).tolist(), - "q01": np.quantile(proprios, 0.01, axis=0).tolist(), - "q99": np.quantile(proprios, 0.99, axis=0).tolist(), - }, - "num_transitions": num_transitions, - "num_trajectories": num_trajectories, - } - - try: - with tf.io.gfile.GFile(path, "w") as f: - json.dump(metadata, f) - except tf.errors.PermissionDeniedError: - overwatch.warning(f"Could not write dataset statistics to {path}. 
Writing to {local_path} instead.") - os.makedirs(os.path.dirname(local_path), exist_ok=True) - with open(local_path, "w") as f: - json.dump(metadata, f) - - return metadata - - -def save_dataset_statistics(dataset_statistics, run_dir): - """Saves a `dataset_statistics.json` file.""" - out_path = run_dir / "dataset_statistics.json" - with open(out_path, "w") as f_json: - for _, stats in dataset_statistics.items(): - for k in stats["action"].keys(): - if isinstance(stats["action"][k], np.ndarray): - stats["action"][k] = stats["action"][k].tolist() - if "proprio" in stats: - for k in stats["proprio"].keys(): - if isinstance(stats["proprio"][k], np.ndarray): - stats["proprio"][k] = stats["proprio"][k].tolist() - if "num_trajectories" in stats: - if isinstance(stats["num_trajectories"], np.ndarray): - stats["num_trajectories"] = stats["num_trajectories"].item() - if "num_transitions" in stats: - if isinstance(stats["num_transitions"], np.ndarray): - stats["num_transitions"] = stats["num_transitions"].item() - json.dump(dataset_statistics, f_json, indent=2) - overwatch.info(f"Saved dataset statistics file at path {out_path}") - - -def allocate_threads(n: Optional[int], weights: np.ndarray): - """ - Allocates an integer number of threads across datasets based on weights. - - The final array sums to `n`, but each element is no less than 1. If `n` is None, then every dataset is assigned a - value of AUTOTUNE. - """ - if n is None: - return np.array([tf.data.AUTOTUNE] * len(weights)) - - assert np.all(weights >= 0), "Weights must be non-negative" - assert len(weights) <= n, "Number of threads must be at least as large as length of weights" - weights = np.array(weights) / np.sum(weights) - - allocation = np.zeros_like(weights, dtype=int) - while True: - # Give the remaining elements that would get less than 1 a 1 - mask = (weights * n < 1) & (weights > 0) - if not mask.any(): - break - n -= mask.sum() - allocation += mask.astype(int) - - # Recompute the distribution over the remaining elements - weights[mask] = 0 - weights = weights / weights.sum() - - # Allocate the remaining elements - fractional, integral = np.modf(weights * n) - allocation += integral.astype(int) - n -= integral.sum() - for i in np.argsort(fractional)[::-1][: int(n)]: - allocation[i] += 1 - - return allocation diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/goal_relabeling.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/goal_relabeling.py deleted file mode 100644 index c8a394955f7c1d3c2aad2ea8b157e6e06b60ae6b..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/utils/goal_relabeling.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -goal_relabeling.py - -Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. -Each function should add entries to the "task" dict. 
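As a rough illustration of the allocation rule in `allocate_threads` above, the standalone sketch below (a simplified re-implementation for demonstration, not the pipeline code) shows how datasets with small weights are still guaranteed one thread each before the remainder is split proportionally:

```python
import numpy as np

def toy_allocate(n, weights):
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    allocation = np.zeros(len(weights), dtype=int)
    # Datasets whose fair share would round to 0 get exactly 1 thread up front.
    while True:
        mask = (weights * n < 1) & (weights > 0)
        if not mask.any():
            break
        n -= mask.sum()
        allocation += mask.astype(int)
        weights[mask] = 0
        weights = weights / weights.sum()
    # Distribute the rest: integer parts first, then the largest fractional remainders.
    fractional, integral = np.modf(weights * n)
    allocation += integral.astype(int)
    n -= int(integral.sum())
    for i in np.argsort(fractional)[::-1][: int(n)]:
        allocation[i] += 1
    return allocation

print(toy_allocate(8, [10, 1, 1]))  # [6 1 1] -- the small datasets still get a thread each
```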
-""" - -from typing import Dict - -import tensorflow as tf - -from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge - - -def uniform(traj: Dict) -> Dict: - """Relabels with a true uniform distribution over future states.""" - traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] - - # Select a random future index for each transition i in the range [i + 1, traj_len) - rand = tf.random.uniform([traj_len]) - low = tf.cast(tf.range(traj_len) + 1, tf.float32) - high = tf.cast(traj_len, tf.float32) - goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) - - # Sometimes there are floating-point errors that cause an out-of-bounds - goal_idxs = tf.minimum(goal_idxs, traj_len - 1) - - # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) - goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) - traj["task"] = tree_merge(traj["task"], goal) - - return traj diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/task_augmentation.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/task_augmentation.py deleted file mode 100644 index f0d0c8e9785917cabb95742dd21efd4587976aa0..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/datasets/rlds/utils/task_augmentation.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -task_augmentation.py - -Contains basic logic for randomly zeroing out keys in the task specification. -""" - -from typing import Dict - -import tensorflow as tf - -from prismatic.vla.datasets.rlds.utils.data_utils import to_padding - - -def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: - """ - Randomly drops out either the goal images or the language instruction. Only does something if both of - these are present. - - Args: - traj: A dictionary containing trajectory data. Should have a "task" key. - keep_image_prob: The probability of keeping the goal images. The probability of keeping the language - instruction is 1 - keep_image_prob. 
- """ - if "language_instruction" not in traj["task"]: - return traj - - image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} - if not image_keys: - return traj - - traj_len = tf.shape(traj["action"])[0] - should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob - should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] - - for key in image_keys | {"language_instruction"}: - should_keep = should_keep_images if key in image_keys else ~should_keep_images - # pad out the key - traj["task"][key] = tf.where( - should_keep, - traj["task"][key], - to_padding(traj["task"][key]), - ) - # zero out the pad mask dict for the key - traj["task"]["pad_mask_dict"][key] = tf.where( - should_keep, - traj["task"]["pad_mask_dict"][key], - tf.zeros_like(traj["task"]["pad_mask_dict"][key]), - ) - - # when no goal images are present, the goal timestep becomes the final timestep - traj["task"]["timestep"] = tf.where( - should_keep_images, - traj["task"]["timestep"], - traj_len - 1, - ) - - return traj diff --git a/capvector-oft/prismatic/vla/materialize.py b/capvector-oft/prismatic/vla/materialize.py deleted file mode 100644 index 6b267bc24c27e778234f03ee58c48f6d41b34148..0000000000000000000000000000000000000000 --- a/capvector-oft/prismatic/vla/materialize.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -materialize.py - -Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and -exports individual functions for clear control flow. -""" - -from pathlib import Path -from typing import Tuple, Type - -from torch.utils.data import Dataset -from transformers import PreTrainedTokenizerBase - -from prismatic.models.backbones.llm.prompting import PromptBuilder -from prismatic.models.backbones.vision import ImageTransform -from prismatic.util.data_utils import PaddedCollatorForActionPrediction -from prismatic.vla.action_tokenizer import ActionTokenizer -from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset - - -def get_vla_dataset_and_collator( - data_root_dir: Path, - data_mix: str, - image_transform: ImageTransform, - tokenizer: PreTrainedTokenizerBase, - prompt_builder_fn: Type[PromptBuilder], - default_image_resolution: Tuple[int, int, int], - padding_side: str = "right", - predict_stop_token: bool = True, - shuffle_buffer_size: int = 100_000, - train: bool = True, - episodic: bool = False, - image_aug: bool = False, -) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: - """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" - action_tokenizer = ActionTokenizer(tokenizer) - batch_transform = RLDSBatchTransform( - action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token - ) - collator = PaddedCollatorForActionPrediction( - tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side - ) - - # Build RLDS Iterable Dataset - cls = RLDSDataset if not episodic else EpisodicRLDSDataset - dataset = cls( - data_root_dir, - data_mix, - batch_transform, - resize_resolution=default_image_resolution[1:], - shuffle_buffer_size=shuffle_buffer_size, - train=train, - image_aug=image_aug, - ) - - return dataset, action_tokenizer, collator diff --git a/capvector-oft/pyproject.toml b/capvector-oft/pyproject.toml deleted file mode 100644 index 7a2037d18d7df84bd255e6af41f01b3469010cf5..0000000000000000000000000000000000000000 --- 
a/capvector-oft/pyproject.toml +++ /dev/null @@ -1,102 +0,0 @@ -[build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" - -[project] -name = "openvla-oft" -authors = [ - {name = "Moo Jin Kim", email="moojink@stanford.edu"}, - {name = "Chelsea Finn", email="cbfinn@cs.stanford.edu"}, - {name = "Percy Liang", email="pliang@cs.stanford.edu"}, -] -description = "Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success" -version = "0.0.1" -readme = "README.md" -requires-python = ">=3.8" -keywords = ["vision-language-actions models", "fine-tuning", "robot learning"] -license = {file = "LICENSE"} -classifiers = [ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3 :: Only", - "Topic :: Scientific/Engineering :: Artificial Intelligence", -] -dependencies = [ - "accelerate>=0.25.0", - "draccus==0.8.0", - "einops", - # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README) - "huggingface_hub", - "json-numpy", - "jsonlines", - "matplotlib", - "peft==0.11.1", - "protobuf", - "rich", - "sentencepiece==0.1.99", - "timm==0.9.10", - "tokenizers==0.19.1", - "torch==2.2.0", - "torchvision==0.17.0", - "torchaudio==2.2.0", - "transformers @ git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding) - "wandb", - "tensorflow==2.15.0", - "tensorflow_datasets==4.9.3", - "tensorflow_graphics==2021.12.3", - "dlimp @ git+https://github.com/moojink/dlimp_openvla", - "diffusers", - "imageio", - "uvicorn", - "fastapi", - "json-numpy", -] - -[project.optional-dependencies] -dev = [ - "black>=24.2.0", - "gpustat", - "ipython", - "pre-commit", - "ruff>=0.2.2", -] -sagemaker = [ - "boto3", - "sagemaker" -] - -[project.urls] -homepage = "https://github.com/moojink/openvla-oft" -repository = "https://github.com/moojink/openvla-oft" -documentation = "https://github.com/moojink/openvla-oft" - -[tool.setuptools.packages.find] -where = ["."] -exclude = ["cache"] - -[tool.setuptools.package-data] -"prismatic" = ["py.typed"] - -[tool.black] -line-length = 121 -target-version = ["py38", "py39", "py310"] -preview = true - -[tool.ruff] -line-length = 121 -target-version = "py38" - -[tool.ruff.lint] -select = ["A", "B", "E", "F", "I", "RUF", "W"] -ignore = ["F722"] - -[tool.ruff.lint.per-file-ignores] -"__init__.py" = ["E402", "F401"] diff --git a/capvector-oft/readme.md b/capvector-oft/readme.md deleted file mode 100644 index 7af472c4dcbf76c8347da907358d74019faa9d9b..0000000000000000000000000000000000000000 --- a/capvector-oft/readme.md +++ /dev/null @@ -1,151 +0,0 @@ -## 1. Environment Setup -```bash -# Create and activate conda environment -conda create -n capvector-openvla-oft python=3.10.16 -y -conda activate capvector-openvla-oft - -# Install PyTorch -pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 - -# pip install to download dependencies -pip install -e . 
- -# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention) -# =>> If you run into difficulty, try `pip cache remove flash_attn` first -pip install packaging ninja -ninja --version; echo $? # Verify Ninja --> should return exit code "0" -pip install "flash-attn==2.5.5" --no-build-isolation -``` -- If you are uncertain about the version of a dependency, please refer to our [**complete envs list**](envs_list.txt). - - -## 2. Data Preparation -First, clone and install the [LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) and required packages: -```bash -git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git -pip install -e LIBERO -pip install -r experiments/robot/libero/libero_requirements.txt -``` - -(Optional, if you plan to launch training) Then, to download the [LIBERO datasets](https://huggingface.co/datasets/openvla/modified_libero_rlds) that we used in our fine-tuning experiments, run the command below or download them manually. This will download the LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, and LIBERO-10 datasets in RLDS data format (~10 GB total). You can use these to fine-tune openvla-SF or train other methods like OpenVLA. Note that these are the same datasets used in the original OpenVLA project. If needed, see details on how to download the original non-RLDS datasets [here](https://github.com/openvla/openvla?tab=readme-ov-file#libero-setup). -```bash -git clone git@hf.co:datasets/openvla/modified_libero_rlds ./data/libero # or download manually -``` - -Finally, the directory structure will be as below: -``` -capvector-oft - ├── data - · ├── libero - │ ├── libero_10_no_noops - │ │ └── 1.0.0 (It contains some json files and 32 tfrecord files) - │ ├── libero_goal_no_noops - │ │ └── 1.0.0 (It contains some json files and 16 tfrecord files) - │ ├── libero_object_no_noops - │ │ └── 1.0.0 (It contains some json files and 32 tfrecord files) - │ ├── libero_spatial_no_noops - │ │ └── 1.0.0 (It contains some json files and 16 tfrecord files) - │ - └── other benchmarks ... -``` - -## 3. set up a conda environment (see instructions in [SETUP.md](SETUP.md)). - -## 4. Obtain the CapVector on any dataset (e.g. LIBERO/ROBOTWIN) and merge it to obtain $\theta_{meta}$ - -First, download the [OpenVLA](https://huggingface.co/openvla/openvla-7b/tree/main) and place them in the `./ckpts/` folder. The directory structure is as below: -``` -capvector-oft - ├── ckpts - · ├── openvla-7b - │ ├── added_tokens.json - │ ├── model-00001-of-00003.safetensors - │ └── ... - · -``` -Then, -``` -cd capvector-oft -bash capvector/interpolate.sh #LIBERO -bash capvector/initialized_interpolate_shell/get_vector_robotwin.sh #ROBOTWIN or other custom datasets (PROPRIO_DIM modification required) -``` - -The above steps are equivalent to directly downloading the $\theta_{meta}$: - -* [capvector-openvla-7b](https://huggingface.co/haofuly/capvector_models_collection/capvector_openvlaoft/merged_model) - -This $\theta_{meta}$ is obtained from LIBERO Spatial and the capability vector is merged with OpenVLA weights with vector weight = 1. - -Place them in the `./ckpts/` folder. The directory structure is as below: -``` -capvector-oft - ├── ckpts - · ├── openvla-7b - ├── capvector-openvla-7b - · -``` - -## 5. 
Start Training -```bash -conda activate capvector-openvla-oft -cd capvector-oft -``` - -First, be sure you have downloaded the LIBERO datasets, as mentioned in the [Data Preparation Section](#data-preparation): `libero_spatial_no_noops`, `libero_object_no_noops`, `libero_goal_no_noops`, `libero_10_no_noops`. (`"_no_noops"` stands for no no-op actions, i.e., training samples with near-zero actions are filtered out). - -Then, prepare the [lora diff](https://huggingface.co/haofuly/capvector_models_collection/capvector_openvlaoft/diff_parameter) and place it at capvector-oft/capvector/lora_diff. This is used to compute the orthogonal loss. - -Next, launch the fine-tuning script below, replacing `X` in the first line with the number of GPUs. The command below launches fine-tuning on LIBERO-Spatial with the hyperparameters that we used in our paper. Here, batch size 8 per GPU will require ~74 GB VRAM, and batch size 1 per GPU will require ~30 GB VRAM. The training results are stored according to the `--run_root_dir` and `--run_id_override`. - -You can refer to the code block below or directly consult training_scripts/training.sh for reference. -```bash - -torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune_regular_loss.py \ - --vla_path ckpts/capvector-openvla-7b \ - --data_root_dir data/libero/ \ - --dataset_name libero_${TASK}_no_noops \ - --run_root_dir experiments/training_results/ \ - --use_l1_regression True \ - --use_diffusion False \ - --use_film False \ - --num_images_in_input 2 \ - --use_proprio True \ - --batch_size 8 \ - --learning_rate 5e-4 \ - --scheduler CosineAnnealingLR \ - --max_steps 150100 \ - --save_freq 150000 \ - --save_latest_checkpoint_only True \ - --merge_lora_during_training True \ - --regularization_lora_vector_path capvector/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors \ - --regularization_weight 1e-4 \ - --image_aug True \ - --lora_rank 32 \ - --wandb_entity "YOUR_WANDB_ENTITY" \ - --wandb_project "YOUR_WANDB_PROJECT" \ - --run_id_override "$VERSION" - -``` - -The above training command should reproduce our CapVector results if `X = 1` and the 150K step checkpoint is evaluated. - -Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you). - -## 6. Inference - -Then, run the commands below to start evaluations with the independently trained checkpoints: -```bash -python experiments/robot/libero/run_libero_eval.py \ - --pretrained_checkpoint experiments/training_results/$VERSION \ - --task_suite_name libero_${TASK} -``` - -Notes: -* The evaluation script will run 500 trials by default (10 tasks x 50 episodes each). You can modify the number of - trials per task by setting `--num_trials_per_task`. You can also change the random seed via `--seed`. 
There are - other arguments in the script; we set them to the default values that work with the openvla-SF checkpoints above. -* The evaluation script logs results locally. You can also log results in Weights & Biases - by setting `--use_wandb True` and specifying `--wandb_project ` and `--wandb_entity `. -* The results reported in our paper were obtained using **Python 3.10.16, PyTorch 2.2.0, and - [bidirectional transformers](https://github.com/moojink/transformers-openvla-oft.git)** - on an **NVIDIA H100 GPU**. Please stick to these package versions if possible. diff --git a/capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py b/capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py deleted file mode 100644 index 70082c93e1fa1f8ba2dc96e528466d0d619e5d38..0000000000000000000000000000000000000000 --- a/capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py +++ /dev/null @@ -1,237 +0,0 @@ -""" -convert_prismatic_weights_to_hf.py - -Utility script for converting full Prismatic VLM weights (from this repository, in the default "Prismatic" format) to -the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`` -via `trust_remote_code = True`. - -Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the -line, with first-class support. -""" - -import json -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, List, Union - -import draccus -import timm -import torch -import torch.nn as nn -from huggingface_hub import hf_hub_download -from timm.models.vision_transformer import LayerScale -from transformers import AutoTokenizer - -from prismatic.extern.hf.configuration_prismatic import PrismaticConfig -from prismatic.extern.hf.modeling_prismatic import PrismaticForConditionalGeneration -from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor - - -@dataclass -class HFConvertConfig: - # fmt: off - prismatic_model_path_or_id: Union[str, Path] = ( # Path to Pretrained VLM (on disk or HF Hub) - "siglip-224px+7b" - # "prism-dinosiglip-224px+7b" - ) - output_hf_model_local_path: Path = Path( # Path to Local Path to save HF model - "hf-convert/prismatic-siglip-224px-7b" - ) - output_hf_model_hub_path: str = ( # Path to HF Hub Path for "final" HF model - "TRI-ML/prismatic-siglip-224px-7b" # => huggingface.co/TRI-ML/prismatic-{...} - ) - - # HF Hub Credentials (required for Gated Models like LLaMa-2) - hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token - - def __post_init__(self) -> None: - self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token - - # fmt: on - - -# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. 
-# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 -# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 -def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: - return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor - - -def ls_apply_patch(ls_module: LayerScale): - ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) - ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale) - del ls_module.gamma - - -# === Conversion Constants === -PROJECTOR_KEY_MAPPING = { - "projector.0.weight": "projector.fc1.weight", - "projector.0.bias": "projector.fc1.bias", - "projector.2.weight": "projector.fc2.weight", - "projector.2.bias": "projector.fc2.bias", - "projector.4.weight": "projector.fc3.weight", - "projector.4.bias": "projector.fc3.bias", -} - - -def remap_state_dicts_for_hf( - projector_state_dict: Dict[str, torch.Tensor], - llm_backbone_state_dict: Dict[str, torch.Tensor], - vision_backbone_state_dicts: List[Dict[str, torch.Tensor]], -) -> Dict[str, torch.Tensor]: - """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion.""" - hf_state_dict = {} - - # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING` - for key, value in projector_state_dict.items(): - hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value - - # Iterate through LLM Backbone =>> replace `llm.` with `language_model.` - for key, value in llm_backbone_state_dict.items(): - hf_state_dict[key.replace("llm.", "language_model.")] = value - - # Iterate through Vision Backbone =>> add "vision_backbone." prefix - assert len(vision_backbone_state_dicts) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!" 
-    for idx, vision_backbone_state_dict in enumerate(vision_backbone_state_dicts):
-        prefix = "vision_backbone.featurizer" if idx == 0 else "vision_backbone.fused_featurizer"
-        for key, value in vision_backbone_state_dict.items():
-            hf_state_dict[f"{prefix}.{key}"] = value
-
-    return hf_state_dict
-
-
-@draccus.wrap()
-def convert_prismatic_weights_to_hf(cfg: HFConvertConfig) -> None:
-    print(f"[*] Converting Prismatic Model `{cfg.prismatic_model_path_or_id}` to HF Transformers Format")
-    torch.set_default_dtype(torch.bfloat16)
-
-    # Get `config.json` and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
-    if os.path.isdir(cfg.prismatic_model_path_or_id):
-        print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.prismatic_model_path_or_id))}`")
-        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
-
-        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
-        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
-    else:
-        print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.prismatic_model_path_or_id}`")
-        config_json = hf_hub_download("TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/config.json")
-        checkpoint_pt = hf_hub_download(
-            "TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/checkpoints/latest-checkpoint.pt"
-        )
-
-    # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
-    with open(config_json, "r") as f:
-        prismatic_config = json.load(f)["model"]
-
-    # Create HF PrismaticConfig (`transformers.PretrainedConfig`)
-    hf_config = PrismaticConfig(
-        vision_backbone_id=prismatic_config["vision_backbone_id"],
-        llm_backbone_id=prismatic_config["llm_backbone_id"],
-        arch_specifier=prismatic_config["arch_specifier"],
-        image_resize_strategy=prismatic_config["image_resize_strategy"],
-        llm_max_length=prismatic_config["llm_max_length"],
-        torch_dtype=torch.bfloat16,
-    )
-
-    # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
-    # TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
-    print("[*] Instantiating and Patching Tokenizer, LLM Config")
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
-    )
-    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-    tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...
-    assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
-    assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"
-
-    # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
-    hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
-    hf_config.text_config.pad_token_id = hf_config.pad_token_id
-    hf_config.text_config.torch_dtype = torch.bfloat16
-    assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"
- - # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform` - # =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py` - print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor") - timm_vision_backbones, input_sizes, interpolations, means, stds = [], [], [], [], [] - for idx, timm_model_id in enumerate(hf_config.timm_model_ids): - timm_vision_backbone = timm.create_model( - timm_model_id, - pretrained=True, - num_classes=0, - img_size=hf_config.image_sizes[idx], - act_layer=hf_config.timm_override_act_layers[idx], - ) - timm_vision_backbones.append(timm_vision_backbone) - - # Get Per-Backbone Image Processing - data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone) - input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx])) - interpolations.append(data_cfg["interpolation"]) - means.append(data_cfg["mean"]) - stds.append(data_cfg["std"]) - - # Patch `LayerScale` because of HF annoying `fix_key` overwrite... - for module in timm_vision_backbone.modules(): - if isinstance(module, LayerScale): - ls_apply_patch(module) - - # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`) - hf_image_processor = PrismaticImageProcessor( - use_fused_vision_backbone=hf_config.use_fused_vision_backbone, - image_resize_strategy=hf_config.image_resize_strategy, - input_sizes=input_sizes, - interpolations=interpolations, - means=means, - stds=stds, - ) - - # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor) - print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor") - hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer) - - # Load Prismatic Model State Dictionary (in preparation for conversion) - print("[*] Loading Prismatic VLM State Dictionary from Checkpoint") - model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"] - assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?" - assert ("projector" in model_state_dict) and ("llm_backbone" in model_state_dict), "Missing keys!" 
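Before the actual conversion below runs on real weights, the effect of `remap_state_dicts_for_hf` is easy to see on toy dictionaries. This sketch uses dummy tensors and made-up keys purely for illustration.

```python
import torch

dummy = torch.zeros(1)  # stand-in tensor; real weights come from the checkpoint
hf_sd = remap_state_dicts_for_hf(
    projector_state_dict={"projector.0.weight": dummy, "projector.0.bias": dummy},
    llm_backbone_state_dict={"llm.model.embed_tokens.weight": dummy},
    vision_backbone_state_dicts=[{"blocks.0.attn.qkv.weight": dummy}],
)
print(sorted(hf_sd.keys()))
# ['language_model.model.embed_tokens.weight',
#  'projector.fc1.bias',
#  'projector.fc1.weight',
#  'vision_backbone.featurizer.blocks.0.attn.qkv.weight']
```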
- - # Convert - print("[*] Running Conversion") - converted_state_dict = remap_state_dicts_for_hf( - model_state_dict["projector"], - model_state_dict["llm_backbone"], - vision_backbone_state_dicts=[vb.state_dict() for vb in timm_vision_backbones], - ) - - # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM - print("[*] Building (Randomly Initialized) Model =>> PrismaticForConditionalGeneration") - hf_model = PrismaticForConditionalGeneration(hf_config) - hf_model.load_state_dict(converted_state_dict, strict=True, assign=True) - - # Cast Model to BF16 before Saving - hf_model.to(torch.bfloat16) - - # Save Pretrained Versions to Local Path - print("[*] Saving Model & Processor to Local Path") - hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB") - hf_image_processor.save_pretrained(cfg.output_hf_model_local_path) - hf_processor.save_pretrained(cfg.output_hf_model_local_path) - - # Register AutoClasses - PrismaticConfig.register_for_auto_class() - PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor") - PrismaticProcessor.register_for_auto_class("AutoProcessor") - PrismaticForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq") - - # Push to Hub - print("[*] Pushing Model & Processor to HF Hub") - hf_config.push_to_hub(cfg.output_hf_model_hub_path) - hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB") - hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path) - hf_processor.push_to_hub(cfg.output_hf_model_hub_path) - - -if __name__ == "__main__": - convert_prismatic_weights_to_hf() diff --git a/capvector-oft/scripts/extern/verify_prismatic.py b/capvector-oft/scripts/extern/verify_prismatic.py deleted file mode 100644 index 3717bfb96250117ac754529c18e090b62a0aa4c4..0000000000000000000000000000000000000000 --- a/capvector-oft/scripts/extern/verify_prismatic.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -verify_prismatic.py - -Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate(). -""" - -import time - -import requests -import torch -from PIL import Image -from transformers import AutoModelForVision2Seq, AutoProcessor - -# === Verification Arguments === -MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b" -DEFAULT_IMAGE_URL = ( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" -) - -if "-prism-" in MODEL_PATH: - SAMPLE_PROMPTS_FOR_GENERATION = [ - "In: What is sitting in the coffee?\nOut:", - "In: What's the name of the food on the plate?\nOut:", - "In: caption.\nOut:", - "In: how many beinets..?\nOut:", - "In: Can you give me a lyrical description of the scene\nOut:", - ] -else: - SYSTEM_PROMPT = ( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." - ) - SAMPLE_PROMPTS_FOR_GENERATION = [ - f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:", - f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:", - f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:", - f"{SYSTEM_PROMPT} USER: how many beinets..? 
ASSISTANT:", - f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:", - ] - - -@torch.inference_mode() -def verify_prismatic() -> None: - print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`") - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - # Load Processor & VLM - print("[*] Instantiating Processor and Pretrained VLM") - processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) - - # === AUTOCAST MODE === - # print("[*] Loading in BF16 Autocast Mode") - # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to( - # device, dtype=torch.bfloat16 - # ) - - # === NATIVE BFLOAT16 MODE === - # print("[*] Loading in BF16") - # vlm = AutoModelForVision2Seq.from_pretrained( - # MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True - # ).to(device) - - # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] === - print("[*] Loading in BF16 with Flash-Attention Enabled") - vlm = AutoModelForVision2Seq.from_pretrained( - MODEL_PATH, - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ).to(device) - - # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === - # print("[*] Loading in 8-Bit Quantization Mode") - # vlm = AutoModelForVision2Seq.from_pretrained( - # MODEL_PATH, - # attn_implementation="flash_attention_2", - # torch_dtype=torch.float16, - # quantization_config=BitsAndBytesConfig(load_in_8bit=True), - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # ) - - # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === - # print("[*] Loading in 4-Bit Quantization Mode") - # vlm = AutoModelForVision2Seq.from_pretrained( - # MODEL_PATH, - # attn_implementation="flash_attention_2", - # torch_dtype=torch.float16, - # quantization_config=BitsAndBytesConfig(load_in_4bit=True), - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # ) - - # Iterate over Sample Prompts =>> Generate - image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB") - num_tokens, total_time = 0, 0.0 - - print("[*] Iterating over Sample Prompts\n===\n") - for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION): - # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) === - # inputs = processor(prompt, image).to(device) - # - # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py` - # # =>> Running in native BF16 is also fine (but leads to slightly different generations) - # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): - # gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512) - - # === BFLOAT16 MODE === - inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) - - # === 8-BIT/4-BIT QUANTIZATION MODE === - # inputs = processor(prompt, image).to(device, dtype=torch.float16) - - # Run Inference - gen_ids = None - for _ in range(5): - start_time = time.time() - gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512) - total_time += time.time() - start_time - - gen_ids = gen_ids[0, inputs.input_ids.shape[1] :] - num_tokens += len(gen_ids) - - # === - gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip() - print(f"[{idx + 1}] Input Prompt => {prompt}\n 
Generated => {gen_text}\n") - - # Compute Tokens / Second - print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }") - - -if __name__ == "__main__": - verify_prismatic() diff --git a/capvector-oft/training_scripts/training.sh b/capvector-oft/training_scripts/training.sh deleted file mode 100644 index 639bf80e35de552b249d800c0126602f77df63f5..0000000000000000000000000000000000000000 --- a/capvector-oft/training_scripts/training.sh +++ /dev/null @@ -1,36 +0,0 @@ -VERSION="v0" -TASK="10" # spatial / object / goal / 10 / 90 -VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_spatial_v0.4.2" -DATA_ROOT_DIR="data/libero_openvla" -RUN_ROOT_DIR="experiments/training_results" -REGULARIZATION_LORA_VECTOR_PATH="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors" -WANDB_ENTITY="YOUR_WANDB_ENTITY" -WANDB_PROJECT="YOUR_WANDB_PROJECT" -EVAL_LOG_PATH="experiments/eval_logs/${VERSION}_output.log" - -torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune_regular_loss.py \ - --vla_path "$VLA_PATH" \ - --data_root_dir "$DATA_ROOT_DIR" \ - --dataset_name libero_${TASK}_no_noops \ - --run_root_dir "$RUN_ROOT_DIR" \ - --use_l1_regression True \ - --use_diffusion False \ - --use_film False \ - --num_images_in_input 2 \ - --use_proprio True \ - --batch_size 8 \ - --learning_rate 5e-4 \ - --scheduler CosineAnnealingLR \ - --max_steps 150100 \ - --save_freq 150000 \ - --save_latest_checkpoint_only True \ - --merge_lora_during_training True \ - --regularization_lora_vector_path "$REGULARIZATION_LORA_VECTOR_PATH" \ - --regularization_weight 1e-4 \ - --image_aug True \ - --lora_rank 32 \ - --wandb_entity "$WANDB_ENTITY" \ - --wandb_project "$WANDB_PROJECT" \ - --run_id_override "$VERSION" - -python experiments/robot/libero/run_libero_eval.py --pretrained_checkpoint "$RUN_ROOT_DIR/$VERSION" --task_suite_name libero_${TASK} > "$EVAL_LOG_PATH" 2>&1 diff --git a/capvector-oft/vla-scripts/deploy.py b/capvector-oft/vla-scripts/deploy.py deleted file mode 100644 index a7d743bf132cabff28c75ea0342913f8aeae2c52..0000000000000000000000000000000000000000 --- a/capvector-oft/vla-scripts/deploy.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -deploy.py - -Starts VLA server which the client can query to get robot actions. -""" - -import os.path - -# ruff: noqa: E402 -import json_numpy - -json_numpy.patch() -import json -import logging -import numpy as np -import traceback -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import draccus -import torch -import uvicorn -from fastapi import FastAPI -from fastapi.responses import JSONResponse -from PIL import Image -from transformers import AutoModelForVision2Seq, AutoProcessor - -from experiments.robot.openvla_utils import ( - get_vla, - get_vla_action, - get_action_head, - get_processor, - get_proprio_projector, -) -from experiments.robot.robot_utils import ( - get_image_resize_size, -) -from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX - - -def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str: - return f"In: What action should the robot take to {instruction.lower()}?\nOut:" - - -# === Server Interface === -class OpenVLAServer: - def __init__(self, cfg) -> Path: - """ - A simple server for OpenVLA models; exposes `/act` to predict an action for a given observation + instruction. 
- """ - self.cfg = cfg - - # Load model - self.vla = get_vla(cfg) - - # Load proprio projector - self.proprio_projector = None - if cfg.use_proprio: - self.proprio_projector = get_proprio_projector(cfg, self.vla.llm_dim, PROPRIO_DIM) - - # Load continuous action head - self.action_head = None - if cfg.use_l1_regression or cfg.use_diffusion: - self.action_head = get_action_head(cfg, self.vla.llm_dim) - - # Check that the model contains the action un-normalization key - assert cfg.unnorm_key in self.vla.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!" - - # Get Hugging Face processor - self.processor = None - self.processor = get_processor(cfg) - - # Get expected image dimensions - self.resize_size = get_image_resize_size(cfg) - - - def get_server_action(self, payload: Dict[str, Any]) -> str: - try: - if double_encode := "encoded" in payload: - # Support cases where `json_numpy` is hard to install, and numpy arrays are "double-encoded" as strings - assert len(payload.keys()) == 1, "Only uses encoded payload!" - payload = json.loads(payload["encoded"]) - - observation = payload - instruction = observation["instruction"] - - action = get_vla_action( - self.cfg, self.vla, self.processor, observation, instruction, action_head=self.action_head, proprio_projector=self.proprio_projector, use_film=self.cfg.use_film, - ) - - if double_encode: - return JSONResponse(json_numpy.dumps(action)) - else: - return JSONResponse(action) - except: # noqa: E722 - logging.error(traceback.format_exc()) - logging.warning( - "Your request threw an error; make sure your request complies with the expected format:\n" - "{'observation': dict, 'instruction': str}\n" - ) - return "error" - - def run(self, host: str = "0.0.0.0", port: int = 8777) -> None: - self.app = FastAPI() - self.app.post("/act")(self.get_server_action) - uvicorn.run(self.app, host=host, port=port) - - -@dataclass -class DeployConfig: - # fmt: off - - # Server Configuration - host: str = "0.0.0.0" # Host IP Address - port: int = 8777 # Host Port - - ################################################################################################################# - # Model-specific parameters - ################################################################################################################# - model_family: str = "openvla" # Model family - pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path - - use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective - use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM) - num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training - num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference - use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features - num_images_in_input: int = 3 # Number of images in the VLA input (default: 3) - use_proprio: bool = True # Whether to include proprio state in input - - center_crop: bool = True # Center crop? (if trained w/ random crop image aug) - - lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!) 
- - unnorm_key: Union[str, Path] = "" # Action un-normalization key - use_relative_actions: bool = False # Whether to use relative actions (delta joint angles) - - load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization - load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization - - ################################################################################################################# - # Utils - ################################################################################################################# - seed: int = 7 # Random Seed (for reproducibility) - # fmt: on - - -@draccus.wrap() -def deploy(cfg: DeployConfig) -> None: - server = OpenVLAServer(cfg) - server.run(cfg.host, port=cfg.port) - - -if __name__ == "__main__": - deploy() diff --git a/capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py b/capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py deleted file mode 100644 index b23fb1c26edf14d3efc01c6137653e23a1011b8d..0000000000000000000000000000000000000000 --- a/capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -convert_openvla_weights_to_hf.py - -Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to -the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`` -via `trust_remote_code = True`. - -Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the -line, with first-class support. - -Usage: - python vla-scripts/extern/convert_openvla_weights_to_hf.py \ - --openvla_model_path_or_id \ - --output_hf_model_local_path -""" - -import json -import os -import shutil -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, Union - -import draccus -import timm -import torch -import torch.nn as nn -from huggingface_hub import hf_hub_download -from timm.models.vision_transformer import LayerScale -from transformers import AutoTokenizer - -from prismatic.conf import ModelConfig -from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig -from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction -from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor - - -@dataclass -class HFConvertConfig: - # fmt: off - openvla_model_path_or_id: Union[str, Path] = ( # Path to Pretrained VLA (on disk or HF Hub) - "runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7" - ) - output_hf_model_local_path: Path = Path( # Path to Local Path to save HF model - "hf-convert/openvla-7b" - ) - output_hf_model_hub_path: str = "openvla/openvla-7b" # (Optional) Path to HF Hub Path to push - # model to - - # HF Hub Credentials (required for Gated Models like LLaMa-2) - hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token - - def __post_init__(self) -> None: - self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token - - # fmt: on - - -# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. 
-# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 -# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 -def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: - return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor - - -def ls_apply_patch(ls_module: LayerScale): - ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) - ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale) - del ls_module.gamma - - -# === Conversion Constants === -PROJECTOR_KEY_MAPPING = { - "projector.0.weight": "projector.fc1.weight", - "projector.0.bias": "projector.fc1.bias", - "projector.2.weight": "projector.fc2.weight", - "projector.2.bias": "projector.fc2.bias", - "projector.4.weight": "projector.fc3.weight", - "projector.4.bias": "projector.fc3.bias", -} - - -def remap_state_dicts_for_hf( - prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor], - projector_state_dict: Dict[str, torch.Tensor], - llm_backbone_state_dict: Dict[str, torch.Tensor], - use_fused_vision_backbone: bool = False, -) -> Dict[str, torch.Tensor]: - """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion.""" - hf_state_dict = {} - - # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING` - for key, value in projector_state_dict.items(): - hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value - - # Iterate through LLM Backbone =>> replace `llm.` with `language_model.` - for key, value in llm_backbone_state_dict.items(): - hf_state_dict[key.replace("llm.", "language_model.")] = value - - # Iterate through Vision Backbone =>> add "vision_backbone." prefix - if not use_fused_vision_backbone: - for key, value in prismatic_vision_backbone_state_dict.items(): - hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value - else: - # Note =>> Assumes that backbones are always DINO + SigLIP... - for key, value in prismatic_vision_backbone_state_dict.items(): - if key.startswith("dino_featurizer"): - if key.endswith(".gamma"): - # Handle `LayerScale gamma` =>> DINOv2 only! 
- key = key.replace(".gamma", ".scale_factor") - hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value - elif key.startswith("siglip_featurizer"): - hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value - - return hf_state_dict - - -@draccus.wrap() -def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None: - print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format") - torch.set_default_dtype(torch.bfloat16) - - # Get `config.json`, 'dataset_statistics.json' and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py` - if os.path.isdir(cfg.openvla_model_path_or_id): - print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`") - config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" - dataset_statistics_json = run_dir / "dataset_statistics.json" - - assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" - assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" - assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" - else: - print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.openvla_model_path_or_id}`") - config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json") - checkpoint_pt = hf_hub_download( - "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt" - ) - dataset_statistics_json = hf_hub_download( - "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json" - ) - - # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer - with open(config_json, "r") as f: - vla_cfg = json.load(f)["vla"] - prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__ - - # Load Normalization Statistics - with open(dataset_statistics_json, "r") as f: - norm_stats = json.load(f) - - # Create HF OpenVLAConfig (`transformers.PretrainedConfig`) - hf_config = OpenVLAConfig( - vision_backbone_id=prismatic_config["vision_backbone_id"], - llm_backbone_id=prismatic_config["llm_backbone_id"], - arch_specifier=prismatic_config["arch_specifier"], - image_resize_strategy=prismatic_config["image_resize_strategy"], - llm_max_length=prismatic_config["llm_max_length"], - torch_dtype=torch.bfloat16, - norm_stats=norm_stats, - ) - - # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer` - # TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`! - print("[*] Instantiating and Patching Tokenizer, LLM Config") - tokenizer = AutoTokenizer.from_pretrained( - hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right" - ) - tokenizer.add_special_tokens({"pad_token": ""}) - tokenizer.init_kwargs.pop("add_prefix_space", None) # Pop to prevent unnecessary warning on reload... - assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!" - assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!" 
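Complementing the earlier sketch, the fused-backbone branch of this file's `remap_state_dicts_for_hf` can also be exercised on dummy tensors; the point of interest is that DINOv2 `.gamma` parameters are renamed to `.scale_factor` so they line up with the patched `LayerScale` modules. The keys below are illustrative, not real checkpoint keys.

```python
import torch

dummy = torch.zeros(1)  # stand-in tensor
hf_sd = remap_state_dicts_for_hf(
    prismatic_vision_backbone_state_dict={
        "dino_featurizer.blocks.0.ls1.gamma": dummy,          # DINOv2 LayerScale parameter
        "siglip_featurizer.blocks.0.attn.qkv.weight": dummy,  # SigLIP weight
    },
    projector_state_dict={"projector.0.weight": dummy},
    llm_backbone_state_dict={"llm.lm_head.weight": dummy},
    use_fused_vision_backbone=True,
)
print(sorted(hf_sd.keys()))
# ['language_model.lm_head.weight',
#  'projector.fc1.weight',
#  'vision_backbone.featurizer.blocks.0.ls1.scale_factor',
#  'vision_backbone.fused_featurizer.blocks.0.attn.qkv.weight']
```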
- - # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate - hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of - hf_config.text_config.pad_token_id = hf_config.pad_token_id - hf_config.text_config.torch_dtype = torch.bfloat16 - assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!" - - # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform` - # =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py` - print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor") - input_sizes, interpolations, means, stds = [], [], [], [] - for idx, timm_model_id in enumerate(hf_config.timm_model_ids): - timm_vision_backbone = timm.create_model( - timm_model_id, - pretrained=True, - num_classes=0, - img_size=hf_config.image_sizes[idx], - act_layer=hf_config.timm_override_act_layers[idx], - ) - - # Get Per-Backbone Image Processing - data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone) - input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx])) - interpolations.append(data_cfg["interpolation"]) - means.append(data_cfg["mean"]) - stds.append(data_cfg["std"]) - - # Patch `LayerScale` because of HF annoying `fix_key` overwrite... - for module in timm_vision_backbone.modules(): - if isinstance(module, LayerScale): - ls_apply_patch(module) - - # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`) - hf_image_processor = PrismaticImageProcessor( - use_fused_vision_backbone=hf_config.use_fused_vision_backbone, - image_resize_strategy=hf_config.image_resize_strategy, - input_sizes=input_sizes, - interpolations=interpolations, - means=means, - stds=stds, - ) - - # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor) - print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor") - hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer) - - # Load Prismatic Model State Dictionary (in preparation for conversion) - print("[*] Loading Prismatic VLM State Dictionary from Checkpoint") - model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"] - assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?" - assert all([k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]]), "Missing keys!" 
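The `norm_stats` loaded above (and the `dataset_statistics.json` copied alongside the converted checkpoint later in this script) exist so that, at inference time, normalized action predictions can be mapped back to each dimension's physical range. The sketch below shows a generic bounds-style un-normalization; the field names (`q01` / `q99` percentiles) are assumptions about the statistics schema, so consult `dataset_statistics.json` for the real layout.

```python
import numpy as np

# Hedged sketch: map actions from [-1, 1] back to a per-dimension percentile range.
# Field names (q01 / q99) are assumed; the real schema lives in dataset_statistics.json.
def unnormalize(actions: np.ndarray, q01: np.ndarray, q99: np.ndarray) -> np.ndarray:
    return 0.5 * (actions + 1.0) * (q99 - q01) + q01

q01 = np.array([-0.05, -0.05, -1.0])
q99 = np.array([0.05, 0.05, 1.0])
print(unnormalize(np.array([0.0, 1.0, -1.0]), q01, q99))  # -> [0., 0.05, -1.]
```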
- - # Convert - print("[*] Running Conversion") - converted_state_dict = remap_state_dicts_for_hf( - model_state_dict["vision_backbone"], - model_state_dict["projector"], - model_state_dict["llm_backbone"], - use_fused_vision_backbone=hf_config.use_fused_vision_backbone, - ) - - # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM - print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction") - hf_model = OpenVLAForActionPrediction(hf_config) - hf_model.load_state_dict(converted_state_dict, strict=True, assign=True) - - # Cast Model to BF16 before Saving - hf_model.to(torch.bfloat16) - - # Save Pretrained Versions to Local Path - print("[*] Saving Model & Processor to Local Path") - hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB") - hf_image_processor.save_pretrained(cfg.output_hf_model_local_path) - hf_processor.save_pretrained(cfg.output_hf_model_local_path) - - # Copy `dataset_statistics.json` File to Converted Checkpoint Directory - output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json" - shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json) - - print(f"[*] Saving Complete! Saved converted checkpoint to: {cfg.output_hf_model_local_path}") - - ##################################################################################### - # Optional: Push Model to Hugging Face Hub - ##################################################################################### - - # # Register AutoClasses - # OpenVLAConfig.register_for_auto_class() - # PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor") - # PrismaticProcessor.register_for_auto_class("AutoProcessor") - # OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq") - - # # Push to HF Hub - # print("[*] Pushing Model & Processor to HF Hub") - # hf_config.push_to_hub(cfg.output_hf_model_hub_path) - # hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB") - # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path) - # hf_processor.push_to_hub(cfg.output_hf_model_hub_path) - - -if __name__ == "__main__": - convert_openvla_weights_to_hf() diff --git a/capvector-oft/vla-scripts/extern/verify_openvla.py b/capvector-oft/vla-scripts/extern/verify_openvla.py deleted file mode 100644 index e2d8e82bd2902b3e8aef003ef88c2ab5ce213d3c..0000000000000000000000000000000000000000 --- a/capvector-oft/vla-scripts/extern/verify_openvla.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -verify_openvla.py - -Given an HF-exported OpenVLA model, attempt to load via AutoClasses, and verify forward() and predict_action(). -""" - -import time - -import numpy as np -import torch -from PIL import Image -from transformers import AutoModelForVision2Seq, AutoProcessor - -# === Verification Arguments -MODEL_PATH = "openvla/openvla-7b" -SYSTEM_PROMPT = ( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." -) -INSTRUCTION = "put spoon on towel" - - -def get_openvla_prompt(instruction: str) -> str: - if "v01" in MODEL_PATH: - return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? 
ASSISTANT:" - else: - return f"In: What action should the robot take to {instruction.lower()}?\nOut:" - - -@torch.inference_mode() -def verify_openvla() -> None: - print(f"[*] Verifying OpenVLAForActionPrediction using Model `{MODEL_PATH}`") - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - # Load Processor & VLA - print("[*] Instantiating Processor and Pretrained OpenVLA") - processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) - - # === BFLOAT16 + FLASH-ATTN MODE === - print("[*] Loading in BF16 with Flash-Attention Enabled") - vla = AutoModelForVision2Seq.from_pretrained( - MODEL_PATH, - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ).to(device) - - # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === - # print("[*] Loading in 8-Bit Quantization Mode") - # vla = AutoModelForVision2Seq.from_pretrained( - # MODEL_PATH, - # attn_implementation="flash_attention_2", - # torch_dtype=torch.float16, - # quantization_config=BitsAndBytesConfig(load_in_8bit=True), - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # ) - - # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === - # print("[*] Loading in 4-Bit Quantization Mode") - # vla = AutoModelForVision2Seq.from_pretrained( - # MODEL_PATH, - # attn_implementation="flash_attention_2", - # torch_dtype=torch.float16, - # quantization_config=BitsAndBytesConfig(load_in_4bit=True), - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # ) - - print("[*] Iterating with Randomly Generated Images") - for _ in range(100): - prompt = get_openvla_prompt(INSTRUCTION) - image = Image.fromarray(np.asarray(np.random.rand(256, 256, 3) * 255, dtype=np.uint8)) - - # === BFLOAT16 MODE === - inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) - - # === 8-BIT/4-BIT QUANTIZATION MODE === - # inputs = processor(prompt, image).to(device, dtype=torch.float16) - - # Run OpenVLA Inference - start_time = time.time() - action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False) - print(f"\t=>> Time: {time.time() - start_time:.4f} || Action: {action}") - - -if __name__ == "__main__": - verify_openvla() diff --git a/capvector-oft/vla-scripts/finetune.py b/capvector-oft/vla-scripts/finetune.py deleted file mode 100644 index f0c23cf5b5249188b4c335924578b10e9847df36..0000000000000000000000000000000000000000 --- a/capvector-oft/vla-scripts/finetune.py +++ /dev/null @@ -1,1152 +0,0 @@ -""" -finetune.py - -Fine-tunes OpenVLA via LoRA. 
-""" - -import os -import time -from collections import deque -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, Optional, Tuple, Type - -import draccus -import torch -import torch.distributed as dist -import torch.nn as nn -import tqdm -from accelerate import PartialState -from huggingface_hub import HfApi, snapshot_download -from peft import LoraConfig, PeftModel, get_peft_model -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import AdamW -from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data import DataLoader -from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor -from transformers.modeling_outputs import CausalLMOutputWithPast - -import wandb -os.environ["WANDB_MODE"]="offline" - -from experiments.robot.openvla_utils import ( - check_model_logic_mismatch, - model_is_on_hf_hub, - update_auto_map, -) - -from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig -from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction -from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor -from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead -from prismatic.models.backbones.llm.prompting import PurePromptBuilder -from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone -from prismatic.models.projectors import ( - NoisyActionProjector, - ProprioProjector, -) -from prismatic.training.train_utils import ( - compute_actions_l1_loss, - compute_token_accuracy, - get_current_action_mask, - get_next_actions_mask, -) -from prismatic.util.data_utils import PaddedCollatorForActionPrediction -from prismatic.vla.action_tokenizer import ActionTokenizer -from prismatic.vla.constants import ( - ACTION_DIM, - ACTION_PROPRIO_NORMALIZATION_TYPE, - NUM_ACTIONS_CHUNK, - PROPRIO_DIM, -) -from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset -from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics - -# Sane Defaults -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - -import debugpy -try: - debugpy.listen(("localhost", 9501)) - print("Waiting for debugger attach") - debugpy.wait_for_client() -except Exception as e: - pass - - -@dataclass -class FinetuneConfig: - # fmt: off - vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally) - - # Dataset - data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets - dataset_name: str = "aloha_scoop_x_into_bowl" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`) - run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints - shuffle_buffer_size: int = 100_000 # Dataloader shuffle buffer size (can reduce if OOM errors occur) - - # Algorithm and architecture - use_l1_regression: bool = True # If True, trains continuous action head with L1 regression objective - use_diffusion: bool = False # If True, trains continuous action head with diffusion modeling objective (DDIM) - num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training - use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features - num_images_in_input: int = 1 # Number of images in the VLA input (default: 1) - use_proprio: bool = False # If True, includes robot proprioceptive state in input - - # Training configuration - batch_size: int = 8 # Batch size per device 
(total batch size = batch_size * num GPUs) - learning_rate: float = 5e-4 # Learning rate - lr_warmup_steps: int = 0 # Number of steps to warm up learning rate (from 10% to 100%) - num_steps_before_decay: int = 100_000 # Number of steps before LR decays by 10x - grad_accumulation_steps: int = 1 # Number of gradient accumulation steps - max_steps: int = 200_000 # Max number of training steps - use_val_set: bool = False # If True, uses validation set and log validation metrics - val_freq: int = 10_000 # (When `use_val_set==True`) Validation set logging frequency in steps - val_time_limit: int = 180 # (When `use_val_set==True`) Time limit for computing validation metrics - save_freq: int = 10_000 # Checkpoint saving frequency in steps - save_latest_checkpoint_only: bool = False # If True, saves only 1 checkpoint, overwriting latest checkpoint - # (If False, saves all checkpoints) - resume: bool = False # If True, resumes from checkpoint - resume_step: Optional[int] = None # (When `resume==True`) Step number that we are resuming from - image_aug: bool = True # If True, trains with image augmentations (HIGHLY RECOMMENDED) - diffusion_sample_freq: int = 50 # (When `use_diffusion==True`) Frequency for sampling in steps - - # LoRA - use_lora: bool = True # If True, uses LoRA fine-tuning - lora_rank: int = 32 # Rank of LoRA weight matrix - lora_dropout: float = 0.0 # Dropout applied to LoRA weights - merge_lora_during_training: bool = True # If True, merges LoRA weights and saves result during training - # Note: Merging can be very slow on some machines. If so, set to - # False and merge final checkpoint offline! - - # Logging - wandb_entity: str = "your-wandb-entity" # Name of WandB entity - wandb_project: str = "your-wandb-project" # Name of WandB project - run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging - run_id_override: Optional[str] = None # Optional string to override the run ID with - wandb_log_freq: int = 10 # WandB logging frequency in steps - - # fmt: on - - -def remove_ddp_in_checkpoint(state_dict) -> dict: - """ - Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using - DistributedDataParallel (DDP). - - When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters - prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when - loading into models that are not yet wrapped in DDP. - - Args: - state_dict (dict): PyTorch model state dictionary. - - Returns: - dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names. - Parameters without the 'module.' prefix remain unchanged. - """ - new_state_dict = {} - for k, v in state_dict.items(): - if k[:7] == "module.": - new_state_dict[k[7:]] = v - else: - new_state_dict[k] = v - return new_state_dict - - -def get_run_id(cfg) -> str: - """ - Generates or retrieves an identifier string for an experiment run. - - Args: - cfg (FinetuneConfig): Training configuration. - - Returns: - str: Experiment run ID. 
- """ - if cfg.run_id_override is not None: - # Override the run ID with the user-provided ID - run_id = cfg.run_id_override - elif cfg.resume: - # Override run ID with the previous resumed run's ID - run_id = cfg.vla_path.split("/")[-1] - # Remove the "--XXX_chkpt" suffix from the run ID if it exists - if "chkpt" in run_id.split("--")[-1]: - run_id = "--".join(run_id.split("--")[:-1]) - else: - run_id = ( - f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}" - f"+b{cfg.batch_size * cfg.grad_accumulation_steps}" - f"+lr-{cfg.learning_rate}" - ) - if cfg.use_lora: - run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}" - if cfg.image_aug: - run_id += "--image_aug" - if cfg.run_id_note is not None: - run_id += f"--{cfg.run_id_note}" - return run_id - - -def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict: - """ - Loads a checkpoint for a given module. - - Args: - module_name (str): Name of model component to load checkpoint for. - path (str): Path to checkpoint directory. - step (int): Gradient step number of saved checkpoint. - device (str): String specifying how to remap storage locations (default = "cpu"). - - Returns: - dict: PyTorch model state dictionary. - """ - checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt") - print(f"Loading checkpoint: {checkpoint_path}") - state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device) - return remove_ddp_in_checkpoint(state_dict) - - -def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP: - """ - Wrap a module with DistributedDataParallel. - - Args: - module (nn.Module): PyTorch module. - device_id (str): Device ID. - find_unused (bool): Whether to detect parameters without gradients in distributed training. - - Returns: - DistributedDataParallel: PyTorch module wrapped with DDP. - """ - return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True) - - -def count_parameters(module: nn.Module, name: str) -> None: - """ - Counts and prints the number of trainable parameters in a module. - - Args: - module (nn.Module): PyTorch module. - module_name (str): Name of model component. - - Returns: - None. - """ - num_params = sum(p.numel() for p in module.parameters() if p.requires_grad) - print(f"# trainable params in {name}: {num_params}") - - -def init_module( - module_class: Type[nn.Module], - module_name: str, - cfg: FinetuneConfig, - device_id: int, - module_args: dict, - to_bf16: bool = False, - find_unused_params: bool = False, -) -> DDP: - """ - Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP. - - Args: - module_class (Type[nn.Module]): Class of PyTorch module to initialize. - module_name (str): Name of model component to load checkpoint for. - cfg (FinetuneConfig): Training configuration. - device_id (str): Device ID. - module_args (dict): Args for initializing the module. - to_bf16 (bool): Whether to convert to torch.bfloat16 data type. - find_unused_params (bool): Whether to detect parameters without gradients in distributed training. - - Returns: - DistributedDataParallel: PyTorch module wrapped with DDP. 
- """ - module = module_class(**module_args) - count_parameters(module, module_name) - - if cfg.resume: - state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step) - module.load_state_dict(state_dict) - - if to_bf16: - module = module.to(torch.bfloat16) - module = module.to(device_id) - - return wrap_ddp(module, device_id, find_unused_params) - - -def run_forward_pass( - vla, - action_head, - noisy_action_projector, - proprio_projector, - batch, - action_tokenizer, - device_id, - use_l1_regression, - use_diffusion, - use_proprio, - use_film, - num_patches, - compute_diffusion_l1=False, - num_diffusion_steps_train=None, -) -> Tuple[torch.Tensor, Dict[str, float]]: - """ - Compute model forward pass and metrics for both training and validation. - - Args: - vla (OpenVLAForActionPrediction): Vision-language-action policy. - action_head (nn.Module): Action head module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - proprio_projector (nn.Module): Proprioceptive state projector module. - batch (dict): Input batch. - action_tokenizer (ActionTokenizer): Action tokenizer. - device_id (str): Device ID. - use_l1_regression (bool): Whether to use L1 regression. - use_diffusion (bool): Whether to use diffusion. - use_proprio (bool): Whether to use proprioceptive state as input. - use_film (bool): Whether to use FiLM for better language following. - num_patches (int): Number of vision patches. - compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every - diffusion_sample_freq steps during training; do it every batch for validation) - num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion). - - Returns: - tuple: (loss, metrics_dict) - loss: The loss tensor with gradient for backpropagation. - metrics_dict: Dictionary of computed metrics (detached values for logging). 
- """ - metrics = {} - - # Get ground-truth action labels - ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16) - - # [Only for diffusion] Sample noisy actions used as input for noise predictor network - if use_diffusion: - noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions) - noise, noisy_actions, diffusion_timestep_embeddings = ( - noisy_dict["noise"], - noisy_dict["noisy_actions"], - noisy_dict["diffusion_timestep_embeddings"], - ) - else: - noise, noisy_actions, diffusion_timestep_embeddings = None, None, None - - # VLA forward pass - with torch.autocast("cuda", dtype=torch.bfloat16): - output: CausalLMOutputWithPast = vla( - input_ids=batch["input_ids"].to(device_id), - attention_mask=batch["attention_mask"].to(device_id), - pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id), - labels=batch["labels"], - output_hidden_states=True, - proprio=batch["proprio"] if use_proprio else None, - proprio_projector=proprio_projector if use_proprio else None, - noisy_actions=noisy_actions if use_diffusion else None, - noisy_action_projector=noisy_action_projector if use_diffusion else None, - diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None, - use_film=use_film, - ) - - # Get action masks needed for logging - ground_truth_token_ids = batch["labels"][:, 1:].to(device_id) - current_action_mask = get_current_action_mask(ground_truth_token_ids) - next_actions_mask = get_next_actions_mask(ground_truth_token_ids) - - # Compute metrics for discrete action representation (next-token prediction) - if not (use_l1_regression or use_diffusion): - loss = output.loss - predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2) - curr_action_accuracy = compute_token_accuracy( - predicted_token_ids, ground_truth_token_ids, mask=current_action_mask - ) - curr_action_l1_loss = compute_actions_l1_loss( - action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask - ) - next_actions_accuracy = compute_token_accuracy( - predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask - ) - next_actions_l1_loss = compute_actions_l1_loss( - action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask - ) - metrics.update( - { - "loss_value": loss.item(), # Detached value for logging - "curr_action_accuracy": curr_action_accuracy.item(), - "curr_action_l1_loss": curr_action_l1_loss.item(), - "next_actions_accuracy": next_actions_accuracy.item(), - "next_actions_l1_loss": next_actions_l1_loss.item(), - } - ) - # Compute metrics for continuous action representations (L1 regression | diffusion) - else: - # Get last layer hidden states - last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) - # Get hidden states for text portion of prompt+response (after the vision patches) - text_hidden_states = last_hidden_states[:, num_patches:-1] - # Get hidden states for action portion of response - batch_size = batch["input_ids"].shape[0] - actions_hidden_states = ( - text_hidden_states[current_action_mask | next_actions_mask] - .reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1) - .to(torch.bfloat16) - ) # (B, act_chunk_len, D) - - if use_l1_regression: - # Predict action - predicted_actions = action_head.module.predict_action(actions_hidden_states) - # Get full L1 loss - loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions) - - if use_diffusion: - # Predict noise - noise_pred = action_head.module.predict_noise(actions_hidden_states) - # Get 
diffusion noise prediction MSE loss - noise_pred = noise_pred.reshape(noise.shape) - loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean") - - # Only sample actions and compute L1 losses if specified - if compute_diffusion_l1: - with torch.no_grad(): - predicted_actions = run_diffusion_sampling( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector, - proprio_projector=proprio_projector, - batch=batch, - batch_size=batch_size, - num_patches=num_patches, - actions_shape=ground_truth_actions.shape, - device_id=device_id, - current_action_mask=current_action_mask, - next_actions_mask=next_actions_mask, - use_proprio=use_proprio, - use_film=use_film, - ) - - metrics.update( - { - "loss_value": loss.item(), # Detached value for logging - } - ) - - # Get detailed L1 losses for logging - should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1) - if should_log_l1_loss: - ground_truth_curr_action = ground_truth_actions[:, 0] - predicted_curr_action = predicted_actions[:, 0] - ground_truth_next_actions = ground_truth_actions[:, 1:] - predicted_next_actions = predicted_actions[:, 1:] - curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action) - next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions) - metrics.update( - { - "curr_action_l1_loss": curr_action_l1_loss.item(), - "next_actions_l1_loss": next_actions_l1_loss.item(), - } - ) - - # Return both the loss tensor (with gradients) and the metrics dictionary (with detached values) - return loss, metrics - - -def run_diffusion_sampling( - vla, - action_head, - noisy_action_projector, - proprio_projector, - batch, - batch_size, - num_patches, - actions_shape, - device_id, - current_action_mask, - next_actions_mask, - use_proprio, - use_film, -) -> torch.Tensor: - """ - Run diffusion sampling (reverse diffusion) to generate actions. - - Args: - vla (OpenVLAForActionPrediction): Vision-language-action policy. - action_head (nn.Module): Action head module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - proprio_projector (nn.Module): Proprioceptive state projector module. - batch (dict): Input batch. - batch_size (int): Batch size. - num_patches (int): Number of vision patches. - actions_shape (tuple): Shape of ground-truth actions. - device_id (str): Device ID. - current_action_mask (torch.Tensor): Mask for current action. - next_actions_mask (torch.Tensor): Mask for next actions. - use_proprio (bool): Whether to use proprioceptive state as input. - use_film (bool): Whether to use FiLM for better language following. - - Returns: - torch.Tensor: Predicted actions. 
- """ - # Sample random noisy action, used as the starting point for reverse diffusion - noise = torch.randn( - size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), - device=device_id, - dtype=torch.bfloat16, - ) # (B, chunk_len, action_dim) - - # Set diffusion timestep values - action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train) - - # Reverse diffusion: Iteratively denoise to generate action, conditioned on observation - curr_noisy_actions = noise - for t in action_head.module.noise_scheduler.timesteps: - # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding, - # and diffusion timestep embedding) - timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id) - diffusion_timestep_embeddings = ( - action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device) - ) # (B, llm_dim) - diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) - - with torch.autocast("cuda", dtype=torch.bfloat16): - output = vla( - input_ids=batch["input_ids"].to(device_id), - attention_mask=batch["attention_mask"].to(device_id), - pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id), - labels=batch["labels"], - output_hidden_states=True, - proprio=batch["proprio"] if use_proprio else None, - proprio_projector=proprio_projector if use_proprio else None, - noisy_actions=curr_noisy_actions, - noisy_action_projector=noisy_action_projector, - diffusion_timestep_embeddings=diffusion_timestep_embeddings, - use_film=use_film, - ) - # Get last layer hidden states - last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) - # Get hidden states for text portion of prompt+response (after the vision patches) - text_hidden_states = last_hidden_states[:, num_patches:-1] - # Get hidden states for action portion of response - actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape( - batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1 - ) # (B, act_chunk_len, D) - actions_hidden_states = actions_hidden_states.to(torch.bfloat16) - # Predict noise - noise_pred = action_head.module.predict_noise(actions_hidden_states) - - # Compute the action at the previous diffusion timestep: x_t -> x_{t-1} - curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample - - return curr_noisy_actions.reshape(actions_shape) - - -def compute_smoothened_metrics(metrics_deques) -> dict: - """ - Compute smoothened metrics from recent deques. - - Args: - metrics_deques (dict): Dictionary of deques containing recent metrics. - - Returns: - dict: Dictionary of smoothened metrics. - """ - smoothened_metrics = {} - for name, deque in metrics_deques.items(): - if deque and len(deque) > 0: - smoothened_metrics[name] = sum(deque) / len(deque) - return smoothened_metrics - - -def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None: - """ - Log metrics to Weights & Biases. - - Args: - metrics (dict): Dictionary of metrics to log - prefix (str): Prefix for metric names - step (int): Training step - wandb_entity (str): W&B entity instance - - Returns: - None. 
- """ - log_dict = {} - for name, value in metrics.items(): - # Map loss_value to Loss for better readability in W&B - if name == "loss_value": - log_dict[f"{prefix}/Loss"] = value - # Keep other metrics as is - else: - log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value - wandb_entity.log(log_dict, step=step) - - -def save_training_checkpoint( - cfg, - run_dir, - log_step, - vla, - processor, - proprio_projector, - noisy_action_projector, - action_head, - train_dataset, - distributed_state, -) -> None: - """ - Save all training checkpoints including model components, LoRA adapter, and dataset statistics. - - Args: - cfg (FinetuneConfig): Training configuration. - run_dir (Path): Experiment run directory path. - log_step (int): Current logging step. - vla (OpenVLAForActionPrediction): Vision-language-action policy. - processor (PrismaticProcessor): OpenVLA inputs processor. - proprio_projector (nn.Module): Proprioceptive state projector module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - action_head (nn.Module): Action head module. - train_dataset (RLDSDataset): Training dataset. - distributed_state (PartialState): Distributed training state. - - Returns: - None. - """ - # Determine checkpoint paths and naming - if cfg.save_latest_checkpoint_only: - checkpoint_dir = run_dir - checkpoint_name_suffix = "latest_checkpoint.pt" - else: - checkpoint_dir = Path(str(run_dir) + f"--{log_step}_chkpt") - checkpoint_name_suffix = f"{log_step}_checkpoint.pt" - - adapter_dir = checkpoint_dir / "lora_adapter" - - # Create directories and save dataset statistics (main process only) - if distributed_state.is_main_process: - os.makedirs(checkpoint_dir, exist_ok=True) - os.makedirs(adapter_dir, exist_ok=True) - save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir) - print(f"Saving Model Checkpoint for Step {log_step}") - - # Wait for directories to be created - dist.barrier() - - # Save model components (main process only) - if distributed_state.is_main_process: - # Save processor and LoRA adapter - processor.save_pretrained(checkpoint_dir) - vla.module.save_pretrained(adapter_dir) - - # Save other components - if cfg.use_proprio and proprio_projector is not None: - torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}") - - if cfg.use_diffusion and noisy_action_projector is not None: - torch.save( - noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}" - ) - - if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None: - torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}") - - if cfg.use_film: - # To be safe, just save the entire vision backbone (not just FiLM components) - torch.save( - vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}" - ) - - # Wait for model components to be saved - dist.barrier() - - # Merge LoRA weights into base model and save resulting model checkpoint - # Note: Can be very slow on some devices; if so, we recommend merging offline - if cfg.use_lora and cfg.merge_lora_during_training: - base_vla = AutoModelForVision2Seq.from_pretrained( - cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True - ) - merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir) - merged_vla = merged_vla.merge_and_unload() - - if distributed_state.is_main_process: - 
merged_vla.save_pretrained(checkpoint_dir) - print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}") - - # Wait for merged model to be saved - dist.barrier() - - -def run_validation( - vla, - action_head, - noisy_action_projector, - proprio_projector, - val_dataloader, - action_tokenizer, - device_id, - cfg, - num_patches, - log_step, - distributed_state, - val_time_limit, -) -> None: - """ - Compute validation set metrics for logging. - - Args: - vla (OpenVLAForActionPrediction): Vision-language-action policy. - action_head (nn.Module): Action head module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - proprio_projector (nn.Module): Proprioceptive state projector module. - val_dataloader (DataLoader): Validation data loader. - action_tokenizer (ActionTokenizer): Action tokenizer. - device_id (str): Device ID. - cfg (FinetuneConfig): Training configuration. - num_patches (int): Number of vision patches. - log_step (int): Current logging step. - distributed_state (PartialState): Distributed training state. - val_time_limit (int): Time limit for computing validation metrics. - - Returns: - None. - """ - val_start_time = time.time() - vla.eval() - val_batches_count = 0 - - # List to store validation metrics - all_val_metrics = [] - - with torch.no_grad(): - for batch in val_dataloader: - # Always compute L1 loss for validation, even for diffusion - _, metrics = run_forward_pass( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector, - proprio_projector=proprio_projector, - batch=batch, - action_tokenizer=action_tokenizer, - device_id=device_id, - use_l1_regression=cfg.use_l1_regression, - use_diffusion=cfg.use_diffusion, - use_proprio=cfg.use_proprio, - use_film=cfg.use_film, - num_patches=num_patches, - compute_diffusion_l1=True, - num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, - ) - - # Add the loss value to the metrics - metrics["loss"] = metrics["loss_value"] - all_val_metrics.append(metrics) - val_batches_count += 1 - - # Cut testing on validation set short if it exceeds time limit - if time.time() - val_start_time > val_time_limit: - break - - # Compute average validation metrics - avg_val_metrics = {} - for metric_name in all_val_metrics[0].keys(): - values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics] - if values: - avg_val_metrics[metric_name] = sum(values) / len(values) - - # Add batch count to metrics - avg_val_metrics["val_batches_count"] = val_batches_count - - # Log validation metrics to W&B - if distributed_state.is_main_process: - log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb) - - -@draccus.wrap() -def finetune(cfg: FinetuneConfig) -> None: - """ - Fine-tunes base VLA on demonstration dataset via LoRA. - - Allows toggling different action representations (discrete vs. continuous), different learning objectives - (next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs, - such as additional camera images and robot proprioceptive state. Assumes parallel action generation with - action chunking. - - Args: - cfg (FinetuneConfig): Training configuration. - - Returns: - None. - """ - assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!" - assert not (cfg.use_l1_regression and cfg.use_diffusion), ( - "Cannot do both L1 regression and diffusion. Please pick one of them!" 
- ) - - # Trim trailing forward slash ('/') in VLA path if it exists - cfg.vla_path = cfg.vla_path.rstrip("/") - print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`") - - # Get experiment run ID - run_id = get_run_id(cfg) - - # Create experiment run directory - run_dir = cfg.run_root_dir / run_id - os.makedirs(run_dir, exist_ok=True) - - # GPU setup - distributed_state = PartialState() - device_id = distributed_state.local_process_index - torch.cuda.set_device(device_id) - torch.cuda.empty_cache() - - # Initialize wandb logging - if distributed_state.is_main_process: - wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id) - - # Print detected constants - print( - "Detected constants:\n" - f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n" - f"\tACTION_DIM: {ACTION_DIM}\n" - f"\tPROPRIO_DIM: {PROPRIO_DIM}\n" - f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}" - ) - - # Two options: - # (1) Base model is on Hugging Face Hub - # - Then download it and record the path to the download directory - # (2) Base model is stored locally - # - Then register model config in HF Auto Classes - # In both cases, we want to check whether any changes have been made to - # the `modeling_prismatic.py` file in this codebase; if so, we will copy - # the file to the downloaded or locally stored checkpoint directory so - # that the user's changes to the VLA class logic go into effect - if model_is_on_hf_hub(cfg.vla_path): - # Download model directly from Hugging Face Hub - vla_download_path = snapshot_download(repo_id=cfg.vla_path) - # Overwrite VLA path - cfg.vla_path = vla_download_path - else: - # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) - AutoConfig.register("openvla", OpenVLAConfig) - AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) - AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) - AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) - - # Update config.json and sync model files - if distributed_state.is_main_process: - update_auto_map(cfg.vla_path) - check_model_logic_mismatch(cfg.vla_path) - - # Wait for model files to be synced - dist.barrier() - - # Load processor and VLA - processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True) - vla = AutoModelForVision2Seq.from_pretrained( - cfg.vla_path, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ).to(device_id) - - # Set number of images in VLA input - vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input) - - # LoRA setup - if cfg.use_lora: - lora_config = LoraConfig( - r=cfg.lora_rank, - lora_alpha=min(cfg.lora_rank, 16), - lora_dropout=cfg.lora_dropout, - target_modules="all-linear", - init_lora_weights="gaussian", - ) - vla = get_peft_model(vla, lora_config) - vla.print_trainable_parameters() - - # FiLM setup - if cfg.use_film: - count_parameters(vla.vision_backbone, "vla.vision_backbone (original)") - # Wrap vision backbone with FiLM wrapper - # Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the - # latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the - # original one (due to the LoRA wrapper) - vla.model.vision_backbone = FiLMedPrismaticVisionBackbone( - vision_backbone=vla.model.vision_backbone, - llm_dim=vla.llm_dim, - ) - count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)") - if 
cfg.resume: - state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step) - vla.model.vision_backbone.load_state_dict(state_dict) - vla.model.vision_backbone = vla.model.vision_backbone.to(device_id) - - # Wrap VLA with DDP - vla = wrap_ddp(vla, device_id, find_unused=True) - - # If applicable, instantiate proprio projector - if cfg.use_proprio: - proprio_projector = init_module( - ProprioProjector, - "proprio_projector", - cfg, - device_id, - {"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM}, - ) - - # If applicable, instantiate continuous action head for L1 regression - if cfg.use_l1_regression: - action_head = init_module( - L1RegressionActionHead, - "action_head", - cfg, - device_id, - {"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM}, - to_bf16=True, - ) - - # If applicable, instantiate diffusion action head and noisy action projector - if cfg.use_diffusion: - action_head = init_module( - DiffusionActionHead, - "action_head", - cfg, - device_id, - { - "input_dim": vla.module.llm_dim, - "hidden_dim": vla.module.llm_dim, - "action_dim": ACTION_DIM, - "num_diffusion_steps_train": cfg.num_diffusion_steps_train, - }, - to_bf16=True, - ) - noisy_action_projector = init_module( - NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim} - ) - - # Get number of vision patches - NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input() - # If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings - if cfg.use_proprio: - NUM_PATCHES += 1 - # For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings - if cfg.use_diffusion: - NUM_PATCHES += 1 - - # Instantiate optimizer - trainable_params = [param for param in vla.parameters() if param.requires_grad] - if cfg.use_l1_regression or cfg.use_diffusion: - trainable_params += [param for param in action_head.parameters() if param.requires_grad] - if cfg.use_diffusion: - trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad] - if cfg.use_proprio: - trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad] - print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}") - optimizer = AdamW(trainable_params, lr=cfg.learning_rate) - - # Record original learning rate - original_lr = optimizer.param_groups[0]["lr"] - - # Create learning rate scheduler - scheduler = MultiStepLR( - optimizer, - milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change - gamma=0.1, # Multiplicative factor of learning rate decay - ) - - # Create Action Tokenizer - action_tokenizer = ActionTokenizer(processor.tokenizer) - - # Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default. - # =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block. - # =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using - # your own Dataset, make sure to add the appropriate logic to the training loop! 
- # - # --- - # from prismatic.vla.datasets import DummyDataset - # - # train_dataset = DummyDataset( - # action_tokenizer, - # processor.tokenizer, - # image_transform=processor.image_processor.apply_transform, - # prompt_builder_fn=PurePromptBuilder, - # ) - # --- - - # We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s) - use_wrist_image = cfg.num_images_in_input > 1 - - # Create training and optional validation datasets - batch_transform = RLDSBatchTransform( - action_tokenizer, - processor.tokenizer, - image_transform=processor.image_processor.apply_transform, - prompt_builder_fn=PurePromptBuilder, - use_wrist_image=use_wrist_image, - use_proprio=cfg.use_proprio, - ) - train_dataset = RLDSDataset( - cfg.data_root_dir, - cfg.dataset_name, - batch_transform, - resize_resolution=tuple(vla.module.config.image_sizes), - shuffle_buffer_size=cfg.shuffle_buffer_size, - image_aug=cfg.image_aug, - ) - if cfg.use_val_set: - val_dataset = RLDSDataset( - cfg.data_root_dir, - cfg.dataset_name, - batch_transform, - resize_resolution=tuple(vla.module.config.image_sizes), - shuffle_buffer_size=cfg.shuffle_buffer_size // 10, - image_aug=cfg.image_aug, - train=False, - ) - - # [Important] Save dataset statistics so that we can unnormalize actions during inference - if distributed_state.is_main_process: - save_dataset_statistics(train_dataset.dataset_statistics, run_dir) - - # Create collator and dataloader - collator = PaddedCollatorForActionPrediction( - processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right" - ) - dataloader = DataLoader( - train_dataset, - batch_size=cfg.batch_size, - sampler=None, - collate_fn=collator, - num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism - ) - if cfg.use_val_set: - val_batch_size = cfg.batch_size - val_dataloader = DataLoader( - val_dataset, - batch_size=val_batch_size, - sampler=None, - collate_fn=collator, - num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism - ) - - # Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation) - recent_metrics = { - "loss_value": deque(maxlen=cfg.grad_accumulation_steps), - "curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps), - "curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps), - "next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps), - "next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps), - } - - # Start training - with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress: - vla.train() - optimizer.zero_grad() - for batch_idx, batch in enumerate(dataloader): - # Compute training metrics and loss - compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0 - loss, metrics = run_forward_pass( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - proprio_projector=proprio_projector if cfg.use_proprio else None, - batch=batch, - action_tokenizer=action_tokenizer, - device_id=device_id, - use_l1_regression=cfg.use_l1_regression, - use_diffusion=cfg.use_diffusion, - use_proprio=cfg.use_proprio, - use_film=cfg.use_film, - num_patches=NUM_PATCHES, - compute_diffusion_l1=compute_diffusion_l1, - num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, - ) - - # Normalize loss to account for gradient accumulation - normalized_loss = loss / 
cfg.grad_accumulation_steps - - # Backward pass - normalized_loss.backward() - - # Store recent train metrics - for metric_name, value in metrics.items(): - if metric_name in recent_metrics: - recent_metrics[metric_name].append(value) - - # Compute gradient step index - gradient_step_idx = batch_idx // cfg.grad_accumulation_steps - - # Compute smoothened train metrics - smoothened_metrics = compute_smoothened_metrics(recent_metrics) - - # Push Metrics to W&B (every wandb_log_freq gradient steps) - log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx - if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0: - log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb) - - # [If applicable] Linearly warm up learning rate from 10% to 100% of original - if cfg.lr_warmup_steps > 0: - lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0 - current_lr = original_lr * (0.1 + 0.9 * lr_progress) - for param_group in optimizer.param_groups: - param_group["lr"] = current_lr - - if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0: - # Log the learning rate - # Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay) - wandb.log( - { - "VLA Train/Learning Rate": scheduler.get_last_lr()[0], - }, - step=log_step, - ) - - # Optimizer and LR scheduler step - if (batch_idx + 1) % cfg.grad_accumulation_steps == 0: - optimizer.step() - scheduler.step() - optimizer.zero_grad() - progress.update() - - # Save model checkpoint: either keep latest checkpoint only or all checkpoints - if gradient_step_idx > 0 and log_step % cfg.save_freq == 0: - save_training_checkpoint( - cfg=cfg, - run_dir=run_dir, - log_step=log_step, - vla=vla, - processor=processor, - proprio_projector=proprio_projector if cfg.use_proprio else None, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None, - train_dataset=train_dataset, - distributed_state=distributed_state, - ) - - # Test model on validation set - if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0: - run_validation( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - proprio_projector=proprio_projector if cfg.use_proprio else None, - val_dataloader=val_dataloader, - action_tokenizer=action_tokenizer, - device_id=device_id, - cfg=cfg, - num_patches=NUM_PATCHES, - log_step=log_step, - distributed_state=distributed_state, - val_time_limit=cfg.val_time_limit, - ) - # Set model back to training mode after validation - vla.train() - - # Stop training when max_steps is reached - if log_step == cfg.max_steps: - print(f"Max step {cfg.max_steps} reached! Stopping training...") - break - - -if __name__ == "__main__": - finetune() diff --git a/capvector-oft/vla-scripts/finetune_regular_loss.py b/capvector-oft/vla-scripts/finetune_regular_loss.py deleted file mode 100644 index eef55749a8f81c57575b8eef1052617db6bd49a5..0000000000000000000000000000000000000000 --- a/capvector-oft/vla-scripts/finetune_regular_loss.py +++ /dev/null @@ -1,1790 +0,0 @@ -#This is for the experiment of CapVector, stopping the gradient propagation in the direction of the new added vector -""" -finetune.py - -Fine-tunes OpenVLA via LoRA. 
-""" - -import os -import ctypes - -lib_path = "/share/miniconda3/lib/libstdc++.so.6" - -try: - ctypes.CDLL(lib_path) - print(f"Successfully preloaded {lib_path}") -except Exception as e: - print(f"Failed to preload {lib_path}: {e}") - -import os -import time -from collections import deque -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, Optional, Tuple, Type - -import draccus -import torch -import torch.distributed as dist -import torch.nn as nn -import tqdm -import numpy as np -from accelerate import PartialState -from huggingface_hub import HfApi, snapshot_download -from peft import LoraConfig, PeftModel, get_peft_model -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import AdamW -from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR -from torch.utils.data import DataLoader -from transformers import get_cosine_schedule_with_warmup -from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor -from transformers.modeling_outputs import CausalLMOutputWithPast - -import wandb -os.environ["WANDB_MODE"]="offline" - -try: - from safetensors import safe_open - SAFETENSORS_AVAILABLE = True -except ImportError: - SAFETENSORS_AVAILABLE = False - print("Warning: safetensors not available, will try torch.load instead") - -from experiments.robot.openvla_utils import ( - check_model_logic_mismatch, - model_is_on_hf_hub, - update_auto_map, -) - -from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig -from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction -from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor -from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead -from prismatic.models.backbones.llm.prompting import PurePromptBuilder -from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone -from prismatic.models.ema_model import EMAModel -from prismatic.models.projectors import ( - NoisyActionProjector, - ProprioProjector, -) -from prismatic.training.train_utils import ( - compute_actions_l1_loss, - compute_token_accuracy, - get_current_action_mask, - get_next_actions_mask, -) -from prismatic.util.data_utils import PaddedCollatorForActionPrediction -from prismatic.vla.action_tokenizer import ActionTokenizer -from prismatic.vla.constants import ( - ACTION_DIM, - ACTION_PROPRIO_NORMALIZATION_TYPE, - NUM_ACTIONS_CHUNK, - PROPRIO_DIM, -) -from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset -from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics - -# Sane Defaults -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -#wx: stop gradient in the feature vector direction - -EPS = 1e-12 - -def register_orthogonal_grad_hook(model, vector_W, debug=False): - name_to_param = dict(model.named_parameters()) - - hooked_A = 0 - hooked_B = 0 - hooked_direct = 0 - - missed = 0 - missed_name = [] - - direct_missed = 0 - direct_missed_name = [] - - printed = {"A": False, "B": False, "D": False} - - def proj_out(g2, v2): - vn2 = (v2 * v2).sum().detach() - if vn2.item() <= EPS: - return g2 - gv = (g2 * v2).sum() - return g2 - (gv / (vn2 + EPS)) * v2 - - for w_name, vW in vector_W.items(): - if "vision_backbone" in w_name: - continue - - prefix = "base_model.model." 
-        A_name = prefix + w_name.replace(".weight", ".lora_A.default.weight")
-        B_name = prefix + w_name.replace(".weight", ".lora_B.default.weight")
-
-        # ===== 1) First, try to hook the LoRA factors =====
-        if A_name in name_to_param and B_name in name_to_param:
-            A = name_to_param[A_name]
-            B = name_to_param[B_name]
-
-            # If neither factor is trainable, skip hooking
-            if (not A.requires_grad) and (not B.requires_grad):
-                continue
-
-            # Pin vW to the right device/dtype
-            vW = vW.to(device=A.device, dtype=A.dtype)
-            vW2 = vW.reshape(vW.shape[0], -1) if vW.ndim != 2 else vW  # [out, in_flat]
-
-            # ---- Hook A: use the current B to compute vA = B^T vW on the fly ----
-            if A.requires_grad:
-                def hook_A(g, A_ref=A, B_ref=B, vW2_ref=vW2):
-                    if g is None:
-                        return None
-                    g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
-
-                    B_mat = B_ref.detach()
-                    B2 = B_mat.reshape(B_mat.shape[0], -1) if B_mat.ndim != 2 else B_mat  # [out, r]
-
-                    if B2.shape[0] != vW2_ref.shape[0]:
-                        return g
-
-                    vA = torch.matmul(B2.transpose(0, 1), vW2_ref)  # [r, in_flat]
-
-                    if debug and not printed["A"]:
-                        print(f"[hook fired] A: ||B||={B2.norm().item():.4e}, ||vA||={vA.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
-                        printed["A"] = True
-
-                    g2_new = proj_out(g2, vA)
-                    return g2_new.view_as(g)
-
-                A.register_hook(hook_A)
-                hooked_A += 1
-
-            # ---- Hook B: use the current A to compute vB = vW A^T on the fly ----
-            if B.requires_grad:
-                def hook_B(g, A_ref=A, B_ref=B, vW2_ref=vW2):
-                    if g is None:
-                        return None
-                    g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
-
-                    A_mat = A_ref.detach()
-                    A2 = A_mat.reshape(A_mat.shape[0], -1) if A_mat.ndim != 2 else A_mat  # [r, in_flat]
-
-                    if A2.shape[1] != vW2_ref.shape[1]:
-                        return g
-
-                    vB = torch.matmul(vW2_ref, A2.transpose(0, 1))  # [out, r]
-
-                    if debug and not printed["B"]:
-                        print(f"[hook fired] B: ||A||={A2.norm().item():.4e}, ||vB||={vB.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
-                        printed["B"] = True
-
-                    g2_new = proj_out(g2, vB)
-                    return g2_new.view_as(g)
-
-                B.register_hook(hook_B)
-                hooked_B += 1
-
-            # This entry was handled by the LoRA branch, move on
-            continue
-
-        # ===== 2) No LoRA factors found: fall back to hooking the parameter directly (e.g., a layernorm weight) =====
-        missed += 1
-        missed_name.append(w_name)
-
-        # Try to map to the non-LoRA parameter name
-        # In most cases this is simply prefixed with base_model.model.
-        direct_name = prefix + w_name
-
-        # Some vector entries may not carry the base_model.model prefix, while the model parameter may use a different one.
-        # Add one more fallback attempt here: if direct_name is not found, also try the name without the added prefix
-        # (you can add further matching rules to suit your own setup)
-        if direct_name not in name_to_param:
-            # Try again: if w_name itself already contains base_model.model, do not add the prefix
-            if w_name in name_to_param:
-                direct_name = w_name
-            else:
-                direct_missed += 1
-                direct_missed_name.append(w_name)
-                continue
-
-        P = name_to_param[direct_name]
-        if not P.requires_grad:
-            # Found but frozen: do not hook it, and do not count it as a direct miss
-            continue
-
-        vP = vector_W[w_name].to(device=P.device, dtype=P.dtype)
-        vP2 = vP.reshape(vP.shape[0], -1) if vP.ndim != 2 else vP
-
-        def hook_direct(g, v_ref=vP2):
-            if g is None:
-                return None
-            g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
-
-            # If the shapes do not match, leave the gradient untouched (avoids size errors inside the hook)
-            if g2.shape != v_ref.shape:
-                return g
-
-            if debug and not printed["D"]:
-                print(f"[hook fired] Direct: param={direct_name}, ||v||={v_ref.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
-                printed["D"] = True
-
-            g2_new = proj_out(g2, v_ref)
-            return g2_new.view_as(g)
-
-        P.register_hook(hook_direct)
-        hooked_direct += 1
-
-    print(
-        f"[hook summary] hooked lora_A: {hooked_A}, lora_B: {hooked_B}, direct: {hooked_direct}, "
-        f"missed(lora-not-found): {missed}, direct_missed: {direct_missed}"
-    )
-
-    # To inspect the concrete miss lists:
-    # print("[missed lora-not-found names]")
-    # for n in missed_name: print("  -", n)
-    # print("[direct_missed names]")
-    # for n in direct_missed_name: print("  -", n)
-
-    # import pdb; pdb.set_trace()
-
-
-# def register_orthogonal_grad_hook(model, vector_W, debug=False):
-#     name_to_param = dict(model.named_parameters())
-
-#     hooked_A = 0
-#     hooked_B = 0
-#     missed = 0
-
-#     printed = {"A": False, "B": False}  # used to print only once
-
-#     for w_name, vW in vector_W.items():
-#         if "vision_backbone" in w_name:
-#             continue
-#         # import pdb; pdb.set_trace()
-#         prefix = "base_model.model."
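At its core, each hook above removes the gradient component that points along the capability-vector direction, i.e. g ← g − (⟨g, v⟩ / ‖v‖²) v, so parameter updates stay orthogonal to the injected vector. A stripped-down sketch of that projection on a single weight (toy shapes, not the LoRA-factored case handled above):

```python
import torch

v = torch.randn(8, 4)                      # capability-vector direction for this weight (fixed)
W = torch.randn(8, 4, requires_grad=True)  # trainable weight

def project_out_hook(g, v=v):
    # g <- g - (<g, v> / ||v||^2) * v, keeping only the component orthogonal to v
    vn2 = (v * v).sum()
    if vn2 <= 1e-12:
        return g
    return g - ((g * v).sum() / vn2) * v

W.register_hook(project_out_hook)

loss = (W * torch.randn_like(W)).sum()
loss.backward()
# After the hook, the stored gradient has (numerically) no component along v
print(torch.abs((W.grad * v).sum()).item())  # ~0
```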
-# A_name = prefix + w_name.replace(".weight", ".lora_A.default.weight") -# B_name = prefix + w_name.replace(".weight", ".lora_B.default.weight") - -# if A_name not in name_to_param or B_name not in name_to_param: -# missed += 1 -# continue - -# A = name_to_param[A_name] -# B = name_to_param[B_name] - -# if (not A.requires_grad) and (not B.requires_grad): -# continue - -# vW = vW.to(device=A.device, dtype=A.dtype) - -# with torch.no_grad(): -# # A_mat = A.detach().view(1, -1) # (1, in) -# # B_mat = B.detach().view(-1, 1) # (out,1) - -# # vA = torch.matmul(B_mat.T, vW) # (1,in) -# # vB = torch.matmul(vW, A_mat.T) # (out,1) -# B_mat = B.detach() -# A_mat = A.detach() -# # import pdb; pdb.set_trace() - -# # 统一把 vW 变成二维: [out, in_flat] -# if vW.ndim != 2: -# vW2 = vW.reshape(vW.shape[0], -1) -# else: -# vW2 = vW - -# # A 也可能不是严格二维(一般是二维,但保险起见)#看了一下AB都是二维 -# if A_mat.ndim != 2: -# A2 = A_mat.reshape(A_mat.shape[0], -1) # [r, in_flat] -# else: -# A2 = A_mat - -# # B 通常是二维 [out, r] -# if B_mat.ndim != 2: -# B2 = B_mat.reshape(B_mat.shape[0], -1) # [out, r] -# else: -# B2 = B_mat - -# # 形状校验:不匹配就跳过这个 w_name(避免再报错) -# # 需要:B2: [out, r] 与 vW2: [out, in_flat] 的 out 对齐 -# # 需要:A2: [r, in_flat] 与 vW2: [out, in_flat] 的 in_flat 对齐 -# if B2.shape[0] != vW2.shape[0] or A2.shape[1] != vW2.shape[1] or A2.shape[0] != B2.shape[1]: -# missed += 1 -# continue - -# vA = torch.matmul(B2.transpose(0, 1), vW2) # [r, in_flat] -# vB = torch.matmul(vW2, A2.transpose(0, 1)) # [out, r] - - -# # hook A -# if A.requires_grad: -# vA_norm2 = (vA * vA).sum().detach() -# if vA_norm2.item() > EPS: -# def make_hook_A(v, vn2): -# def hook(g): -# if debug and not printed["A"]: -# print(f"[hook fired] lora_A grad norm: {g.norm().item():.4e}") -# printed["A"] = True -# gv = (g * v).sum() -# proj = (gv / (vn2 + EPS)) * v -# return g - proj -# return hook - -# A.register_hook(make_hook_A(vA, vA_norm2)) -# hooked_A += 1 - -# # hook B -# if B.requires_grad: -# vB_norm2 = (vB * vB).sum().detach() -# if vB_norm2.item() > EPS: -# def make_hook_B(v, vn2): -# def hook(g): -# if debug and not printed["B"]: -# print(f"[hook fired] lora_B grad norm: {g.norm().item():.4e}") -# printed["B"] = True -# gv = (g * v).sum() -# proj = (gv / (vn2 + EPS)) * v -# return g - proj -# return hook - -# B.register_hook(make_hook_B(vB, vB_norm2)) -# hooked_B += 1 - -# print(f"[hook summary] hooked lora_A: {hooked_A}, hooked lora_B: {hooked_B}, missed: {missed}") -# import pdb; pdb.set_trace() - - - -# 用法: -# vector_sd = torch.load("your_vector.pth")["state_dict"] or similar -# register_orthogonal_grad_hook(model, vector_sd) - - -# import debugpy -# try: -# debugpy.listen(("localhost", 9501)) -# print("Waiting for debugger attach") -# debugpy.wait_for_client() -# except Exception as e: -# pass - - -@dataclass -class FinetuneConfig: - # fmt: off - vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally) - - # Dataset - data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets - dataset_name: str = "aloha_scoop_x_into_bowl" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`) - run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints - shuffle_buffer_size: int = 100_000 # Dataloader shuffle buffer size (can reduce if OOM errors occur) - - # Algorithm and architecture - use_l1_regression: bool = True # If True, trains continuous action head with L1 regression objective - use_diffusion: bool = False # If True, trains continuous action head with diffusion modeling 
objective (DDIM) - num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training - use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features - num_images_in_input: int = 1 # Number of images in the VLA input (default: 1) - use_proprio: bool = False # If True, includes robot proprioceptive state in input - - # Training configuration - batch_size: int = 8 # Batch size per device (total batch size = batch_size * num GPUs) - learning_rate: float = 5e-4 # Learning rate - lr_warmup_steps: int = 0 # Number of steps to warm up learning rate (from 10% to 100%) - num_steps_before_decay: int = 100_000 # Number of steps before LR decays by 10x - grad_accumulation_steps: int = 1 # Number of gradient accumulation steps - max_steps: int = 200_000 # Max number of training steps - use_val_set: bool = False # If True, uses validation set and log validation metrics - val_freq: int = 10_000 # (When `use_val_set==True`) Validation set logging frequency in steps - val_time_limit: int = 180 # (When `use_val_set==True`) Time limit for computing validation metrics - save_freq: int = 10_000 # Checkpoint saving frequency in steps - save_latest_checkpoint_only: bool = False # If True, saves only 1 checkpoint, overwriting latest checkpoint - # (If False, saves all checkpoints) - scheduler: str = 'MultiStepLR' # "MultiStepLR" or "CosineAnnealingLR" or "WarmupCosineLR" - resume: bool = False # If True, resumes from checkpoint - resume_step: Optional[int] = None # (When `resume==True`) Step number that we are resuming from - image_aug: bool = True # If True, trains with image augmentations (HIGHLY RECOMMENDED) - diffusion_sample_freq: int = 50 # (When `use_diffusion==True`) Frequency for sampling in steps - - # LoRA - use_lora: bool = True # If True, uses LoRA fine-tuning - lora_rank: int = 32 # Rank of LoRA weight matrix - lora_dropout: float = 0.0 # Dropout applied to LoRA weights - merge_lora_during_training: bool = True # If True, merges LoRA weights and saves result during training - # Note: Merging can be very slow on some machines. If so, set to - # False and merge final checkpoint offline! - - # Regularization - regularization_lora_vector_path: str = None # Path to regularization vector - regularization_weight: float = 1e-3 # Weight of regularization loss - - # Logging - wandb_entity: str = "your-wandb-entity" # Name of WandB entity - wandb_project: str = "your-wandb-project" # Name of WandB project - run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging - run_id_override: Optional[str] = None # Optional string to override the run ID with - wandb_log_freq: int = 10 # WandB logging frequency in steps - - # EMA - use_ema: bool = False # If True, maintains an EMA copy of the model - inv_gamma: float = 1 # EMA inverse gamma parameter - - # fmt: on - - -def remove_ddp_in_checkpoint(state_dict) -> dict: - """ - Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using - DistributedDataParallel (DDP). - - When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters - prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when - loading into models that are not yet wrapped in DDP. - - Args: - state_dict (dict): PyTorch model state dictionary. - - Returns: - dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names. 
- Parameters without the 'module.' prefix remain unchanged. - """ - new_state_dict = {} - for k, v in state_dict.items(): - if k[:7] == "module.": - new_state_dict[k[7:]] = v - else: - new_state_dict[k] = v - return new_state_dict - - -def get_run_id(cfg) -> str: - """ - Generates or retrieves an identifier string for an experiment run. - - Args: - cfg (FinetuneConfig): Training configuration. - - Returns: - str: Experiment run ID. - """ - if cfg.run_id_override is not None: - # Override the run ID with the user-provided ID - run_id = cfg.run_id_override - elif cfg.resume: - # Override run ID with the previous resumed run's ID - run_id = cfg.vla_path.split("/")[-1] - # Remove the "--XXX_chkpt" suffix from the run ID if it exists - if "chkpt" in run_id.split("--")[-1]: - run_id = "--".join(run_id.split("--")[:-1]) - else: - run_id = ( - f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}" - f"+b{cfg.batch_size * cfg.grad_accumulation_steps}" - f"+lr-{cfg.learning_rate}" - ) - if cfg.use_lora: - run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}" - if cfg.image_aug: - run_id += "--image_aug" - if cfg.run_id_note is not None: - run_id += f"--{cfg.run_id_note}" - return run_id - - -def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict: - """ - Loads a checkpoint for a given module. - - Args: - module_name (str): Name of model component to load checkpoint for. - path (str): Path to checkpoint directory. - step (int): Gradient step number of saved checkpoint. - device (str): String specifying how to remap storage locations (default = "cpu"). - - Returns: - dict: PyTorch model state dictionary. - """ - checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt") - print(f"Loading checkpoint: {checkpoint_path}") - state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device) - return remove_ddp_in_checkpoint(state_dict) - - -def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP: - """ - Wrap a module with DistributedDataParallel. - - Args: - module (nn.Module): PyTorch module. - device_id (str): Device ID. - find_unused (bool): Whether to detect parameters without gradients in distributed training. - - Returns: - DistributedDataParallel: PyTorch module wrapped with DDP. - """ - return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True) - - -def count_parameters(module: nn.Module, name: str) -> None: - """ - Counts and prints the number of trainable parameters in a module. - - Args: - module (nn.Module): PyTorch module. - module_name (str): Name of model component. - - Returns: - None. - """ - num_params = sum(p.numel() for p in module.parameters() if p.requires_grad) - print(f"# trainable params in {name}: {num_params}") - - -def init_module( - module_class: Type[nn.Module], - module_name: str, - cfg: FinetuneConfig, - device_id: int, - module_args: dict, - to_bf16: bool = False, - find_unused_params: bool = False, -) -> DDP: - """ - Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP. - - Args: - module_class (Type[nn.Module]): Class of PyTorch module to initialize. - module_name (str): Name of model component to load checkpoint for. - cfg (FinetuneConfig): Training configuration. - device_id (str): Device ID. - module_args (dict): Args for initializing the module. - to_bf16 (bool): Whether to convert to torch.bfloat16 data type. 
- find_unused_params (bool): Whether to detect parameters without gradients in distributed training. - - Returns: - DistributedDataParallel: PyTorch module wrapped with DDP. - """ - module = module_class(**module_args) - count_parameters(module, module_name) - - if cfg.resume: - state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step) - module.load_state_dict(state_dict) - - if to_bf16: - module = module.to(torch.bfloat16) - module = module.to(device_id) - - return wrap_ddp(module, device_id, find_unused_params) - - -def run_forward_pass( - vla, - action_head, - noisy_action_projector, - proprio_projector, - batch, - action_tokenizer, - device_id, - use_l1_regression, - use_diffusion, - use_proprio, - use_film, - num_patches, - compute_diffusion_l1=False, - num_diffusion_steps_train=None, -) -> Tuple[torch.Tensor, Dict[str, float]]: - """ - Compute model forward pass and metrics for both training and validation. - - Args: - vla (OpenVLAForActionPrediction): Vision-language-action policy. - action_head (nn.Module): Action head module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - proprio_projector (nn.Module): Proprioceptive state projector module. - batch (dict): Input batch. - action_tokenizer (ActionTokenizer): Action tokenizer. - device_id (str): Device ID. - use_l1_regression (bool): Whether to use L1 regression. - use_diffusion (bool): Whether to use diffusion. - use_proprio (bool): Whether to use proprioceptive state as input. - use_film (bool): Whether to use FiLM for better language following. - num_patches (int): Number of vision patches. - compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every - diffusion_sample_freq steps during training; do it every batch for validation) - num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion). - - Returns: - tuple: (loss, metrics_dict) - loss: The loss tensor with gradient for backpropagation. - metrics_dict: Dictionary of computed metrics (detached values for logging). 
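The continuous-action branch below gathers the hidden states of the action tokens with a boolean mask and reshapes them back to a per-sample chunk layout. A small, self-contained sketch of that indexing pattern (the tensor sizes and mask positions here are made up for illustration):

```python
import torch

B, T, D = 2, 10, 16            # batch, text tokens, hidden size (illustrative)
chunk_len, action_dim = 2, 3   # 6 action tokens per sample (illustrative)

text_hidden_states = torch.randn(B, T, D)

# Boolean mask marking the action-token positions (same count per sample)
action_mask = torch.zeros(B, T, dtype=torch.bool)
action_mask[:, 3 : 3 + chunk_len * action_dim] = True

# Mask indexing flattens the selected (batch, token) positions in order,
# so an equal per-sample count lets us reshape back to (B, chunk*dim, D)
actions_hidden_states = text_hidden_states[action_mask].reshape(B, chunk_len * action_dim, -1)
print(actions_hidden_states.shape)  # torch.Size([2, 6, 16])
```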
- """ - metrics = {} - - # Get ground-truth action labels - ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16) - - # [Only for diffusion] Sample noisy actions used as input for noise predictor network - if use_diffusion: - noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions) - noise, noisy_actions, diffusion_timestep_embeddings = ( - noisy_dict["noise"], - noisy_dict["noisy_actions"], - noisy_dict["diffusion_timestep_embeddings"], - ) - else: - noise, noisy_actions, diffusion_timestep_embeddings = None, None, None - - # VLA forward pass - with torch.autocast("cuda", dtype=torch.bfloat16): - output: CausalLMOutputWithPast = vla( - input_ids=batch["input_ids"].to(device_id), - attention_mask=batch["attention_mask"].to(device_id), - pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id), - labels=batch["labels"], - output_hidden_states=True, - proprio=batch["proprio"] if use_proprio else None, - proprio_projector=proprio_projector if use_proprio else None, - noisy_actions=noisy_actions if use_diffusion else None, - noisy_action_projector=noisy_action_projector if use_diffusion else None, - diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None, - use_film=use_film, - ) - - # Get action masks needed for logging - ground_truth_token_ids = batch["labels"][:, 1:].to(device_id) - current_action_mask = get_current_action_mask(ground_truth_token_ids) - next_actions_mask = get_next_actions_mask(ground_truth_token_ids) - - # Compute metrics for discrete action representation (next-token prediction) - if not (use_l1_regression or use_diffusion): - loss = output.loss - predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2) - curr_action_accuracy = compute_token_accuracy( - predicted_token_ids, ground_truth_token_ids, mask=current_action_mask - ) - curr_action_l1_loss = compute_actions_l1_loss( - action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask - ) - next_actions_accuracy = compute_token_accuracy( - predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask - ) - next_actions_l1_loss = compute_actions_l1_loss( - action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask - ) - metrics.update( - { - "loss_value": loss.item(), # Detached value for logging - "curr_action_accuracy": curr_action_accuracy.item(), - "curr_action_l1_loss": curr_action_l1_loss.item(), - "next_actions_accuracy": next_actions_accuracy.item(), - "next_actions_l1_loss": next_actions_l1_loss.item(), - } - ) - # Compute metrics for continuous action representations (L1 regression | diffusion) - else: - # Get last layer hidden states - last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) - # Get hidden states for text portion of prompt+response (after the vision patches) - text_hidden_states = last_hidden_states[:, num_patches:-1] - # Get hidden states for action portion of response - batch_size = batch["input_ids"].shape[0] - actions_hidden_states = ( - text_hidden_states[current_action_mask | next_actions_mask] - .reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1) - .to(torch.bfloat16) - ) # (B, act_chunk_len, D) - - if use_l1_regression: - # Predict action - predicted_actions = action_head.module.predict_action(actions_hidden_states) - # Get full L1 loss - loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions) - - if use_diffusion: - # Predict noise - noise_pred = action_head.module.predict_noise(actions_hidden_states) - # Get 
diffusion noise prediction MSE loss - noise_pred = noise_pred.reshape(noise.shape) - loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean") - - # Only sample actions and compute L1 losses if specified - if compute_diffusion_l1: - with torch.no_grad(): - predicted_actions = run_diffusion_sampling( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector, - proprio_projector=proprio_projector, - batch=batch, - batch_size=batch_size, - num_patches=num_patches, - actions_shape=ground_truth_actions.shape, - device_id=device_id, - current_action_mask=current_action_mask, - next_actions_mask=next_actions_mask, - use_proprio=use_proprio, - use_film=use_film, - ) - - metrics.update( - { - "loss_value": loss.item(), # Detached value for logging - } - ) - - # Get detailed L1 losses for logging - should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1) - if should_log_l1_loss: - ground_truth_curr_action = ground_truth_actions[:, 0] - predicted_curr_action = predicted_actions[:, 0] - ground_truth_next_actions = ground_truth_actions[:, 1:] - predicted_next_actions = predicted_actions[:, 1:] - curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action) - next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions) - metrics.update( - { - "curr_action_l1_loss": curr_action_l1_loss.item(), - "next_actions_l1_loss": next_actions_l1_loss.item(), - } - ) - - # Return both the loss tensor (with gradients) and the metrics dictionary (with detached values) - return loss, metrics - - -def run_diffusion_sampling( - vla, - action_head, - noisy_action_projector, - proprio_projector, - batch, - batch_size, - num_patches, - actions_shape, - device_id, - current_action_mask, - next_actions_mask, - use_proprio, - use_film, -) -> torch.Tensor: - """ - Run diffusion sampling (reverse diffusion) to generate actions. - - Args: - vla (OpenVLAForActionPrediction): Vision-language-action policy. - action_head (nn.Module): Action head module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - proprio_projector (nn.Module): Proprioceptive state projector module. - batch (dict): Input batch. - batch_size (int): Batch size. - num_patches (int): Number of vision patches. - actions_shape (tuple): Shape of ground-truth actions. - device_id (str): Device ID. - current_action_mask (torch.Tensor): Mask for current action. - next_actions_mask (torch.Tensor): Mask for next actions. - use_proprio (bool): Whether to use proprioceptive state as input. - use_film (bool): Whether to use FiLM for better language following. - - Returns: - torch.Tensor: Predicted actions. 
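The denoising loop that follows mirrors the usual scheduler pattern of `set_timesteps`, iterating over `timesteps`, and taking `step(...).prev_sample`. A minimal sketch with a toy noise predictor standing in for the VLA plus action head, assuming the Hugging Face `diffusers` `DDIMScheduler` (the actual scheduler class used by the action head may differ):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)  # number of reverse-diffusion steps

# Toy stand-in for the observation-conditioned noise predictor
noise_predictor = torch.nn.Linear(7, 7)

sample = torch.randn(1, 8, 7)  # (batch, chunk_len, action_dim), pure noise
for t in scheduler.timesteps:
    noise_pred = noise_predictor(sample)                        # predict the noise at step t
    sample = scheduler.step(noise_pred, t, sample).prev_sample  # x_t -> x_{t-1}

print(sample.shape)  # denoised action chunk, torch.Size([1, 8, 7])
```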
- """ - # Sample random noisy action, used as the starting point for reverse diffusion - noise = torch.randn( - size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), - device=device_id, - dtype=torch.bfloat16, - ) # (B, chunk_len, action_dim) - - # Set diffusion timestep values - action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train) - - # Reverse diffusion: Iteratively denoise to generate action, conditioned on observation - curr_noisy_actions = noise - for t in action_head.module.noise_scheduler.timesteps: - # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding, - # and diffusion timestep embedding) - timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id) - diffusion_timestep_embeddings = ( - action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device) - ) # (B, llm_dim) - diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) - - with torch.autocast("cuda", dtype=torch.bfloat16): - output = vla( - input_ids=batch["input_ids"].to(device_id), - attention_mask=batch["attention_mask"].to(device_id), - pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id), - labels=batch["labels"], - output_hidden_states=True, - proprio=batch["proprio"] if use_proprio else None, - proprio_projector=proprio_projector if use_proprio else None, - noisy_actions=curr_noisy_actions, - noisy_action_projector=noisy_action_projector, - diffusion_timestep_embeddings=diffusion_timestep_embeddings, - use_film=use_film, - ) - # Get last layer hidden states - last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) - # Get hidden states for text portion of prompt+response (after the vision patches) - text_hidden_states = last_hidden_states[:, num_patches:-1] - # Get hidden states for action portion of response - actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape( - batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1 - ) # (B, act_chunk_len, D) - actions_hidden_states = actions_hidden_states.to(torch.bfloat16) - # Predict noise - noise_pred = action_head.module.predict_noise(actions_hidden_states) - - # Compute the action at the previous diffusion timestep: x_t -> x_{t-1} - curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample - - return curr_noisy_actions.reshape(actions_shape) - - -def compute_smoothened_metrics(metrics_deques) -> dict: - """ - Compute smoothened metrics from recent deques. - - Args: - metrics_deques (dict): Dictionary of deques containing recent metrics. - - Returns: - dict: Dictionary of smoothened metrics. - """ - smoothened_metrics = {} - for name, deque in metrics_deques.items(): - if deque and len(deque) > 0: - smoothened_metrics[name] = sum(deque) / len(deque) - return smoothened_metrics - - -def compute_diff_regularization_loss(model, diff_params_dict, regularization_weight=1.0): - """ - 计算模型参数和diff_path中同名参数之间的正则化loss,用于防止模型参数向diff_path参数的方向更新。 - 参考正交化loss的实现方式,计算参数之间的内积来惩罚相似性。 - - Args: - model: 模型(可能是DDP包装的) - diff_params_dict: 从diff_path加载的参数字典 - regularization_weight: 正则化权重 - - Returns: - regularization_loss: 正则化loss值 - """ - orthogonal_loss = 0. 
-    matched_count = 0
-
-    # Get the underlying module (in case the model is wrapped in DDP)
-    model_module = model.module if hasattr(model, 'module') else model
-
-    for name, param in model_module.named_parameters():
-        if "lora" in name:
-            if not param.requires_grad:
-                continue
-
-            # Try to match the same-named parameter in diff_params_dict
-            # Naming differences that may need handling:
-            # 1. diff_path entries may lack the "base_model.model." prefix
-            # 2. diff_path entries may have an extra ".default" after .lora_A or .lora_B
-            # For example: the model uses "xxx.lora_A.weight"
-            #              while the diff uses "xxx.lora_A.default.weight"
-            matched_diff_param = None
-
-            # First, try a direct match
-            if name in diff_params_dict:
-                matched_diff_param = diff_params_dict[name]
-            else:
-                # Handle the ".default" difference after .lora_A or .lora_B
-                # Following O-LoRA, only the lora_A parameters are constrained
-                if ".lora_A." in name:
-                    name_with_default = name.replace(".lora_A.default.", ".lora_A.")
-                    if name_with_default in diff_params_dict:
-                        matched_diff_param = diff_params_dict[name_with_default]
-                # elif ".lora_B." in name:
-                #     name_with_default = name.replace(".lora_B.default.", ".lora_B.")
-                #     if name_with_default in diff_params_dict:
-                #         matched_diff_param = diff_params_dict[name_with_default]
-
-            if matched_diff_param is not None:
-                # print(f"Matched parameter: {name}")
-                # Make sure both parameters live on the same device
-                diff_param = matched_diff_param.to(device=param.device, dtype=param.dtype)
-
-                # Check that the shapes match
-                if param.shape == diff_param.shape:
-                    # Clone here to avoid DDP's duplicate parameter-marking issue:
-                    # this creates a new tensor that keeps the gradient connection but does not re-mark the parameter in DDP
-                    param_safe = param.clone()
-                    diff_param_safe = diff_param.detach().clone()
-
-                    # Flatten (also handles multi-dimensional LoRA parameters inside the vision model)
-                    param_flat = param_safe.reshape(-1)  # [N]
-                    diff_param_flat = diff_param_safe.reshape(-1)  # [N]
-                    inner_product = torch.abs((param_flat * diff_param_flat).sum())
-                    orthogonal_loss += inner_product
-                    matched_count += 1
-                    # print(f"Regularization loss for matched parameter {name}: {inner_product}")
-
-    # print(f"Regularization loss: {orthogonal_loss}")
-    if matched_count > 0:
-        orthogonal_loss = orthogonal_loss * regularization_weight
-    else:
-        # If no parameters matched, return 0 (with requires_grad so that backward() does not fail);
-        # the actual gradient is zero, so training is unaffected
-        device = next(model_module.parameters()).device
-        orthogonal_loss = torch.tensor(0.0, device=device, requires_grad=True)
-
-    return orthogonal_loss
-
-
-def load_diff_params(diff_path, device="cpu"):
-    """
-    Load parameters from a safetensors or pth file.
-
-    Args:
-        diff_path: Path to the parameter file.
-        device: Device to load the parameters onto.
-
-    Returns:
-        diff_params_dict: Dictionary of parameters.
-    """
-    diff_params_dict = {}
-
-    if diff_path.endswith('.safetensors'):
-        if not SAFETENSORS_AVAILABLE:
-            raise ImportError("safetensors library is required to load .safetensors files")
-
-        with safe_open(diff_path, framework="pt", device=device) as f:
-            for key in f.keys():
-                diff_params_dict[key] = f.get_tensor(key)
-    else:
-        # Assume a pth file or another torch-serialized format
-        loaded = torch.load(diff_path, map_location=device)
-        if isinstance(loaded, dict):
-            if "state_dict" in loaded:
-                diff_params_dict = loaded["state_dict"]
-            else:
-                diff_params_dict = loaded
-        else:
-            diff_params_dict = loaded
-
-    return diff_params_dict
-
-
-def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None:
-    """
-    Log metrics to Weights & Biases.
-
-    Args:
-        metrics (dict): Dictionary of metrics to log
-        prefix (str): Prefix for metric names
-        step (int): Training step
-        wandb_entity (str): W&B entity instance
-
-    Returns:
-        None.
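The regularization above reduces to an absolute inner product between each trainable `lora_A` weight and its counterpart from the capability vector, scaled by `regularization_weight`; keeping this term small discourages the LoRA update from aligning with the injected direction. A toy sketch of that computation for a single matrix (shapes and names are illustrative):

```python
import torch

regularization_weight = 1e-3

# Current trainable LoRA A weight and the matching capability-vector entry (illustrative shapes)
lora_A = torch.randn(16, 64, requires_grad=True)
vector_A = torch.randn(16, 64)

# |<flatten(A), flatten(V)>| penalizes alignment with the capability-vector direction
inner_product = torch.abs((lora_A.reshape(-1) * vector_A.reshape(-1)).sum())
reg_loss = regularization_weight * inner_product

task_loss = lora_A.pow(2).mean()   # stand-in for the main fine-tuning loss
(task_loss + reg_loss).backward()  # the regularization gradient flows into lora_A
```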
- """ - log_dict = {} - for name, value in metrics.items(): - # Map loss_value to Loss for better readability in W&B - if name == "loss_value": - log_dict[f"{prefix}/Loss"] = value - # Keep other metrics as is - else: - log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value - wandb_entity.log(log_dict, step=step) - - -def save_training_checkpoint( - cfg, - run_dir, - log_step, - vla, - processor, - proprio_projector, - noisy_action_projector, - action_head, - train_dataset, - distributed_state, -) -> None: - """ - Save all training checkpoints including model components, LoRA adapter, and dataset statistics. - - Args: - cfg (FinetuneConfig): Training configuration. - run_dir (Path): Experiment run directory path. - log_step (int): Current logging step. - vla (OpenVLAForActionPrediction): Vision-language-action policy. - processor (PrismaticProcessor): OpenVLA inputs processor. - proprio_projector (nn.Module): Proprioceptive state projector module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - action_head (nn.Module): Action head module. - train_dataset (RLDSDataset): Training dataset. - distributed_state (PartialState): Distributed training state. - - Returns: - None. - """ - # Determine checkpoint paths and naming - if cfg.save_latest_checkpoint_only: - checkpoint_dir = run_dir - checkpoint_name_suffix = "latest_checkpoint.pt" - else: - checkpoint_dir = run_dir / f"{log_step}_chkpt" - checkpoint_name_suffix = f"{log_step}_checkpoint.pt" - - adapter_dir = checkpoint_dir / "lora_adapter" - - # Create directories and save dataset statistics (main process only) - if distributed_state.is_main_process: - os.makedirs(checkpoint_dir, exist_ok=True) - os.makedirs(adapter_dir, exist_ok=True) - save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir) - print(f"Saving Model Checkpoint for Step {log_step}") - - # Wait for directories to be created - dist.barrier() - - # Save model components (main process only) - if distributed_state.is_main_process: - # Save processor and LoRA adapter - processor.save_pretrained(checkpoint_dir) - vla.module.save_pretrained(adapter_dir) - - # Save other components - if cfg.use_proprio and proprio_projector is not None: - torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}") - - if cfg.use_diffusion and noisy_action_projector is not None: - torch.save( - noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}" - ) - - if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None: - torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}") - - if cfg.use_film: - # To be safe, just save the entire vision backbone (not just FiLM components) - torch.save( - vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}" - ) - - # Wait for model components to be saved - dist.barrier() - - # Merge LoRA weights into base model and save resulting model checkpoint - # Note: Can be very slow on some devices; if so, we recommend merging offline - if cfg.use_lora and cfg.merge_lora_during_training: - base_vla = AutoModelForVision2Seq.from_pretrained( - cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True - ) - merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir) - merged_vla = merged_vla.merge_and_unload() - - if distributed_state.is_main_process: - 
merged_vla.save_pretrained(checkpoint_dir) - print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}") - - # Wait for merged model to be saved - dist.barrier() - - -def run_validation( - vla, - action_head, - noisy_action_projector, - proprio_projector, - val_dataloader, - action_tokenizer, - device_id, - cfg, - num_patches, - log_step, - distributed_state, - val_time_limit, -) -> None: - """ - Compute validation set metrics for logging. - - Args: - vla (OpenVLAForActionPrediction): Vision-language-action policy. - action_head (nn.Module): Action head module. - noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). - proprio_projector (nn.Module): Proprioceptive state projector module. - val_dataloader (DataLoader): Validation data loader. - action_tokenizer (ActionTokenizer): Action tokenizer. - device_id (str): Device ID. - cfg (FinetuneConfig): Training configuration. - num_patches (int): Number of vision patches. - log_step (int): Current logging step. - distributed_state (PartialState): Distributed training state. - val_time_limit (int): Time limit for computing validation metrics. - - Returns: - None. - """ - val_start_time = time.time() - vla.eval() - val_batches_count = 0 - - # List to store validation metrics - all_val_metrics = [] - - with torch.no_grad(): - for batch in val_dataloader: - # Always compute L1 loss for validation, even for diffusion - _, metrics = run_forward_pass( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector, - proprio_projector=proprio_projector, - batch=batch, - action_tokenizer=action_tokenizer, - device_id=device_id, - use_l1_regression=cfg.use_l1_regression, - use_diffusion=cfg.use_diffusion, - use_proprio=cfg.use_proprio, - use_film=cfg.use_film, - num_patches=num_patches, - compute_diffusion_l1=True, - num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, - ) - - # Add the loss value to the metrics - metrics["loss"] = metrics["loss_value"] - all_val_metrics.append(metrics) - val_batches_count += 1 - - # Cut testing on validation set short if it exceeds time limit - if time.time() - val_start_time > val_time_limit: - break - - # Compute average validation metrics - avg_val_metrics = {} - for metric_name in all_val_metrics[0].keys(): - values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics] - if values: - avg_val_metrics[metric_name] = sum(values) / len(values) - - # Add batch count to metrics - avg_val_metrics["val_batches_count"] = val_batches_count - - # Log validation metrics to W&B - if distributed_state.is_main_process: - log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb) - - -@draccus.wrap() -def finetune(cfg: FinetuneConfig) -> None: - """ - Fine-tunes base VLA on demonstration dataset via LoRA. - - Allows toggling different action representations (discrete vs. continuous), different learning objectives - (next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs, - such as additional camera images and robot proprioceptive state. Assumes parallel action generation with - action chunking. - - Args: - cfg (FinetuneConfig): Training configuration. - - Returns: - None. - """ - assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!" - assert not (cfg.use_l1_regression and cfg.use_diffusion), ( - "Cannot do both L1 regression and diffusion. Please pick one of them!" 
- ) - - # Trim trailing forward slash ('/') in VLA path if it exists - cfg.vla_path = cfg.vla_path.rstrip("/") - print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`") - - # Get experiment run ID - run_id = get_run_id(cfg) - - # Create experiment run directory - run_dir = cfg.run_root_dir / run_id - os.makedirs(run_dir, exist_ok=True) - - # GPU setup - distributed_state = PartialState() - device_id = distributed_state.local_process_index - torch.cuda.set_device(device_id) - torch.cuda.empty_cache() - - # Initialize wandb logging - if distributed_state.is_main_process: - wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id, id=run_id) - - # Print detected constants - print( - "Detected constants:\n" - f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n" - f"\tACTION_DIM: {ACTION_DIM}\n" - f"\tPROPRIO_DIM: {PROPRIO_DIM}\n" - f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}" - ) - - # Two options: - # (1) Base model is on Hugging Face Hub - # - Then download it and record the path to the download directory - # (2) Base model is stored locally - # - Then register model config in HF Auto Classes - # In both cases, we want to check whether any changes have been made to - # the `modeling_prismatic.py` file in this codebase; if so, we will copy - # the file to the downloaded or locally stored checkpoint directory so - # that the user's changes to the VLA class logic go into effect - if model_is_on_hf_hub(cfg.vla_path): - # Download model directly from Hugging Face Hub - vla_download_path = snapshot_download(repo_id=cfg.vla_path) - # Overwrite VLA path - cfg.vla_path = vla_download_path - else: - # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) - AutoConfig.register("openvla", OpenVLAConfig) - AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) - AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) - AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) - - # Update config.json and sync model files - if distributed_state.is_main_process: - update_auto_map(cfg.vla_path) - check_model_logic_mismatch(cfg.vla_path) - - # Wait for model files to be synced - dist.barrier() - - # Load processor and VLA - processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True) - vla = AutoModelForVision2Seq.from_pretrained( - cfg.vla_path, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ).to(device_id) - - # Set number of images in VLA input - vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input) - - # LoRA setup - if cfg.use_lora: - lora_config = LoraConfig( - r=cfg.lora_rank, - lora_alpha=min(cfg.lora_rank, 16), - lora_dropout=cfg.lora_dropout, - target_modules="all-linear", - init_lora_weights="gaussian", - ) - vla = get_peft_model(vla, lora_config) - vla.print_trainable_parameters() - - # FiLM setup - if cfg.use_film: - count_parameters(vla.vision_backbone, "vla.vision_backbone (original)") - # Wrap vision backbone with FiLM wrapper - # Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the - # latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the - # original one (due to the LoRA wrapper) - vla.model.vision_backbone = FiLMedPrismaticVisionBackbone( - vision_backbone=vla.model.vision_backbone, - llm_dim=vla.llm_dim, - ) - count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)") - if 
cfg.resume:
-            state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step)
-            vla.model.vision_backbone.load_state_dict(state_dict)
-        vla.model.vision_backbone = vla.model.vision_backbone.to(device_id)
-
-    # Wrap VLA with DDP
-    vla = wrap_ddp(vla, device_id, find_unused=False)
-
-    # vla._set_static_graph()
-
-    # If applicable, instantiate proprio projector
-    if cfg.use_proprio:
-        proprio_projector = init_module(
-            ProprioProjector,
-            "proprio_projector",
-            cfg,
-            device_id,
-            {"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM},
-        )
-    else:
-        proprio_projector = None
-
-    # If applicable, instantiate continuous action head for L1 regression
-    if cfg.use_l1_regression:
-        action_head = init_module(
-            L1RegressionActionHead,
-            "action_head",
-            cfg,
-            device_id,
-            {"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM},
-            to_bf16=True,
-        )
-    else:
-        action_head = None
-
-    # If applicable, instantiate diffusion action head and noisy action projector
-    if cfg.use_diffusion:
-        action_head = init_module(
-            DiffusionActionHead,
-            "action_head",
-            cfg,
-            device_id,
-            {
-                "input_dim": vla.module.llm_dim,
-                "hidden_dim": vla.module.llm_dim,
-                "action_dim": ACTION_DIM,
-                "num_diffusion_steps_train": cfg.num_diffusion_steps_train,
-            },
-            to_bf16=True,
-        )
-        noisy_action_projector = init_module(
-            NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim}
-        )
-    else:
-        noisy_action_projector = None
-
-    # EMA
-    if cfg.use_ema:
-        ema_vla = EMAModel(vla,
-            action_head,
-            proprio_projector,
-            noisy_action_projector,
-            inv_gamma=cfg.inv_gamma
-        )
-
-    # Get number of vision patches
-    NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input()
-    # If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings
-    if cfg.use_proprio:
-        NUM_PATCHES += 1
-    # For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings
-    if cfg.use_diffusion:
-        NUM_PATCHES += 1
-
-    diff_path = cfg.regularization_lora_vector_path  # <- set this to your capability vector path
-
-    # Load diff parameters for regularization
-    diff_params_dict = {}
-    if diff_path and os.path.exists(diff_path):
-        print(f"Loading diff parameters from {diff_path}")
-        diff_params_dict = load_diff_params(diff_path, device="cpu")
-        print(f"Loaded {len(diff_params_dict)} parameters from diff_path")
-    else:
-        print(f"Warning: diff_path {diff_path} does not exist, skipping regularization loss")
-
-    # Regularization weight (configured via cfg)
-    regularization_weight = cfg.regularization_weight  # adjust this weight as needed
-
-    # Instantiate optimizer
-    trainable_params = [param for param in vla.parameters() if param.requires_grad]
-    if cfg.use_l1_regression or cfg.use_diffusion:
-        trainable_params += [param for param in action_head.parameters() if param.requires_grad]
-    if cfg.use_diffusion:
-        trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad]
-    if cfg.use_proprio:
-        trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad]
-    print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}")
-    optimizer = AdamW(trainable_params, lr=cfg.learning_rate)
-
-    # Record original learning rate
-    original_lr = optimizer.param_groups[0]["lr"]
-
-    # Create learning rate scheduler
-    if cfg.scheduler == 'MultiStepLR':
-        scheduler = MultiStepLR(
-            optimizer,
-
milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change - gamma=0.1, # Multiplicative factor of learning rate decay - ) - elif cfg.scheduler == 'CosineAnnealingLR': - scheduler = CosineAnnealingLR( - optimizer, - T_max=cfg.max_steps, # Total number of steps for the cosine annealing - eta_min=cfg.learning_rate * 1e-3, - ) - elif cfg.scheduler == 'WarmupCosineLR': - scheduler = get_cosine_schedule_with_warmup( - optimizer, - num_warmup_steps=500, - num_training_steps=cfg.max_steps, - ) - else: - raise ValueError(f"Unsupported scheduler type: {cfg.scheduler}") - - # Create Action Tokenizer - action_tokenizer = ActionTokenizer(processor.tokenizer) - - # Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default. - # =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block. - # =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using - # your own Dataset, make sure to add the appropriate logic to the training loop! - # - # --- - # from prismatic.vla.datasets import DummyDataset - # - # train_dataset = DummyDataset( - # action_tokenizer, - # processor.tokenizer, - # image_transform=processor.image_processor.apply_transform, - # prompt_builder_fn=PurePromptBuilder, - # ) - # --- - - # We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s) - use_wrist_image = cfg.num_images_in_input > 1 - - # Create training and optional validation datasets - batch_transform = RLDSBatchTransform( - action_tokenizer, - processor.tokenizer, - image_transform=processor.image_processor.apply_transform, - prompt_builder_fn=PurePromptBuilder, - use_wrist_image=use_wrist_image, - use_proprio=cfg.use_proprio, - ) - train_dataset = RLDSDataset( - cfg.data_root_dir, - cfg.dataset_name, - batch_transform, - resize_resolution=tuple(vla.module.config.image_sizes), - shuffle_buffer_size=cfg.shuffle_buffer_size, - image_aug=cfg.image_aug, - ) - if cfg.use_val_set: - val_dataset = RLDSDataset( - cfg.data_root_dir, - cfg.dataset_name, - batch_transform, - resize_resolution=tuple(vla.module.config.image_sizes), - shuffle_buffer_size=cfg.shuffle_buffer_size // 10, - image_aug=cfg.image_aug, - train=False, - ) - - # [Important] Save dataset statistics so that we can unnormalize actions during inference - if distributed_state.is_main_process: - save_dataset_statistics(train_dataset.dataset_statistics, run_dir) - - # Create collator and dataloader - collator = PaddedCollatorForActionPrediction( - processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right" - ) - dataloader = DataLoader( - train_dataset, - batch_size=cfg.batch_size, - sampler=None, - collate_fn=collator, - num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism - ) - if cfg.use_val_set: - val_batch_size = cfg.batch_size - val_dataloader = DataLoader( - val_dataset, - batch_size=val_batch_size, - sampler=None, - collate_fn=collator, - num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism - ) - - # Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation) - recent_metrics = { - "loss_value": deque(maxlen=cfg.grad_accumulation_steps), - "curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps), - "curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps), - "next_actions_accuracy": 
deque(maxlen=cfg.grad_accumulation_steps), - "next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps), - "regularization_loss": deque(maxlen=cfg.grad_accumulation_steps), - } - - # Start training - with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress: - vla.train() - optimizer.zero_grad() - for batch_idx, batch in enumerate(dataloader): - # Compute training metrics and loss - compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0 - loss, metrics = run_forward_pass( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - proprio_projector=proprio_projector if cfg.use_proprio else None, - batch=batch, - action_tokenizer=action_tokenizer, - device_id=device_id, - use_l1_regression=cfg.use_l1_regression, - use_diffusion=cfg.use_diffusion, - use_proprio=cfg.use_proprio, - use_film=cfg.use_film, - num_patches=NUM_PATCHES, - compute_diffusion_l1=compute_diffusion_l1, - num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, - ) - - # Add regularization loss if diff_params_dict is available - if diff_params_dict: - ########################### Regularization Loss ########################## - regularization_loss = compute_diff_regularization_loss( - vla, diff_params_dict, regularization_weight=regularization_weight - ) - # print(f"正则化loss: {regularization_loss}") - # print(f"主loss: {loss}") - # 这两行是用于梯度检查的 - # 保存主loss用于梯度检查 - # main_loss = loss.clone() - # reg_loss = regularization_loss.clone() - # print('loss:', loss) - # print('regularization_loss:', regularization_loss) - - # with vla.no_sync(): - # regularization_loss.backward() - - # model_module = vla.module if hasattr(vla, 'module') else vla - # reg_grads = {} - # for name, param in model_module.named_parameters(): - # if "lora_A" in name and param.requires_grad and param.grad is not None: - - # reg_grads[name] = param.grad.clone() - - - dummy_loss = 0.0 - for p in vla.parameters(): - if p.requires_grad: - dummy_loss = dummy_loss + p.sum() * 0.0 - - print('action loss:', loss) - print('regularization_loss:', regularization_loss) - print('dummy_loss:', dummy_loss) - - loss = loss + regularization_loss + dummy_loss - - - - loss.backward() - # main_grads = {} - # for name, param in model_module.named_parameters(): - # if "lora_A" in name and param.requires_grad and param.grad is not None: - - # main_grads[name] = param.grad.clone() - - # print('################################################') - # for name in main_grads.keys(): - # if name in reg_grads: - # main_grad_norm = main_grads[name].norm().item() - # reg_grad_norm = reg_grads[name].norm().item() - # combined_grad_norm = (main_grads[name] + reg_grads[name]).norm().item() - # print(f" {name}:") - # print(f" 主loss梯度norm: {main_grad_norm:.6f}") - # print(f" 正则化loss梯度norm: {reg_grad_norm:.6f}") - # print(f" 合并梯度norm: {combined_grad_norm:.6f}") - - - # print('################################################') - # # Log regularization loss - # metrics["regularization_loss"] = regularization_loss.item() - # ############################################################################# - - # # 这个if下面是用于梯度检查的 - # # 检查两个loss分别对应的梯度(在backward之前) - # if diff_params_dict and batch_idx % cfg.wandb_log_freq == 0: - # # 获取模型参数用于检查梯度 - # model_module = vla.module if hasattr(vla, 'module') else vla - - # # 先清零梯度 - # optimizer.zero_grad() - - # # 只对主loss进行backward - # main_loss_normalized = main_loss / cfg.grad_accumulation_steps - # 
main_loss_normalized.backward(retain_graph=True) - - # # 保存主loss的梯度 - # main_grads = {} - # for name, param in model_module.named_parameters(): - # if "lora_A" in name and param.requires_grad and param.grad is not None: - - # main_grads[name] = param.grad.clone() - - # # 清零梯度,只对正则化loss进行backward - # optimizer.zero_grad() - # reg_loss_normalized = reg_loss / cfg.grad_accumulation_steps - # reg_loss_normalized.backward(retain_graph=True) - - # # 保存正则化loss的梯度 - # reg_grads = {} - # for name, param in model_module.named_parameters(): - # if "lora_A" in name and param.requires_grad and param.grad is not None: - # reg_grads[name] = param.grad.clone() - - # # 打印梯度信息 - # print(f"\n[梯度检查] Step {batch_idx // cfg.grad_accumulation_steps}") - # sample_count = 0 - # for name in main_grads.keys(): - # if name in reg_grads: - # main_grad_norm = main_grads[name].norm().item() - # reg_grad_norm = reg_grads[name].norm().item() - # combined_grad_norm = (main_grads[name] + reg_grads[name]).norm().item() - # print(f" {name}:") - # print(f" 主loss梯度norm: {main_grad_norm:.6f}") - # print(f" 正则化loss梯度norm: {reg_grad_norm:.6f}") - # print(f" 合并梯度norm: {combined_grad_norm:.6f}") - # sample_count += 1 - # if sample_count >= 3: # 只检查前3个参数作为示例 - # break - # print() - - # # 清零梯度,准备正常的backward - # optimizer.zero_grad() - - # # Normalize loss to account for gradient accumulation - # normalized_loss = loss / cfg.grad_accumulation_steps - - # # Backward pass - # normalized_loss.backward() - - # Store recent train metrics - for metric_name, value in metrics.items(): - if metric_name in recent_metrics: - recent_metrics[metric_name].append(value) - - # Compute gradient step index - gradient_step_idx = batch_idx // cfg.grad_accumulation_steps - - # Compute smoothened train metrics - smoothened_metrics = compute_smoothened_metrics(recent_metrics) - - # Push Metrics to W&B (every wandb_log_freq gradient steps) - log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx - if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0: - log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb) - - # [If applicable] Linearly warm up learning rate from 10% to 100% of original - if cfg.lr_warmup_steps > 0: - lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0 - current_lr = original_lr * (0.1 + 0.9 * lr_progress) - for param_group in optimizer.param_groups: - param_group["lr"] = current_lr - - # Optimizer and LR scheduler step - if (batch_idx + 1) % cfg.grad_accumulation_steps == 0: - optimizer.step() - scheduler.step() - optimizer.zero_grad() - progress.update() - if cfg.use_ema: - ema_vla.step(vla, action_head, proprio_projector, noisy_action_projector) - - if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0: - # Log the learning rate - # Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay) - wandb.log( - { - "VLA Train/Learning Rate": scheduler.get_last_lr()[0], - }, - step=log_step, - ) - - if cfg.use_ema: - # Log the EMA decay value - wandb.log( - { - "VLA Train/EMA Decay": ema_vla.decay, - }, - step=log_step, - ) - # Log the EMA eval loss - ema_vla.apply_shadow(vla, action_head, proprio_projector, noisy_action_projector) - with torch.no_grad(): - vla.eval() - action_head.eval() if action_head else None - _, ema_metrics = run_forward_pass( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - 
proprio_projector=proprio_projector if cfg.use_proprio else None, - batch=batch, - action_tokenizer=action_tokenizer, - device_id=device_id, - use_l1_regression=cfg.use_l1_regression, - use_diffusion=cfg.use_diffusion, - use_proprio=cfg.use_proprio, - use_film=cfg.use_film, - num_patches=NUM_PATCHES, - compute_diffusion_l1=compute_diffusion_l1, - num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, - ) - ema_loss = ema_metrics['loss_value'] - vla.train() - action_head.train() if action_head else None - ema_vla.restore(vla, action_head, proprio_projector, noisy_action_projector) - wandb.log( - { - "VLA Train/EMA Loss": ema_loss, - }, - step=log_step, - ) - - # Save model checkpoint: either keep latest checkpoint only or all checkpoints - if gradient_step_idx > 0 and log_step % cfg.save_freq == 0: - save_training_checkpoint( - cfg=cfg, - run_dir=run_dir, - log_step=log_step, - vla=vla, - processor=processor, - proprio_projector=proprio_projector if cfg.use_proprio else None, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None, - train_dataset=train_dataset, - distributed_state=distributed_state, - ) - - if cfg.use_ema: - # Also save EMA model checkpoint - ema_vla.apply_shadow(vla, action_head, proprio_projector, noisy_action_projector) - save_training_checkpoint( - cfg=cfg, - run_dir=run_dir / "ema_model", - log_step=log_step, - vla=vla, - processor=processor, - proprio_projector=proprio_projector if cfg.use_proprio else None, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None, - train_dataset=train_dataset, - distributed_state=distributed_state, - ) - ema_vla.restore(vla, action_head, proprio_projector, noisy_action_projector) - - # Test model on validation set - if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0: - run_validation( - vla=vla, - action_head=action_head, - noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, - proprio_projector=proprio_projector if cfg.use_proprio else None, - val_dataloader=val_dataloader, - action_tokenizer=action_tokenizer, - device_id=device_id, - cfg=cfg, - num_patches=NUM_PATCHES, - log_step=log_step, - distributed_state=distributed_state, - val_time_limit=cfg.val_time_limit, - ) - # Set model back to training mode after validation - vla.train() - - # Stop training when max_steps is reached - if log_step == cfg.max_steps: - print(f"Max step {cfg.max_steps} reached! Stopping training...") - break - - -if __name__ == "__main__": - finetune() diff --git a/capvector-oft/vla-scripts/merge_lora_weights_and_save.py b/capvector-oft/vla-scripts/merge_lora_weights_and_save.py deleted file mode 100644 index 78f96d4bb514a051edf8637e31cf425e3d52212b..0000000000000000000000000000000000000000 --- a/capvector-oft/vla-scripts/merge_lora_weights_and_save.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -Loads a checkpoint that only has a LoRA adapter (no merged model) and merges the adapter -into the base OpenVLA model. Saves the final checkpoint in the same directory. - -Make sure to specify the correct base checkpoint when running this script. 
For example, -- if you fine-tuned the default OpenVLA-7B model without modifications, then `--base_checkpoint=="openvla/openvla-7b"` -- if you fine-tuned a different model or resumed fine-tuning from a different checkpoint, then specify that base checkpoint -- if you fine-tuned the default OpenVLA-7B model with modifications to `modeling_prismatic.py` (OpenVLA class definition), - then the base checkpoint path should point to the checkpoint containing the modifications - -Usage: - python vla-scripts/merge_lora_weights_and_save.py \ - --base_checkpoint openvla/openvla-7b \ - --lora_finetuned_checkpoint_dir /PATH/TO/CHECKPOINT/DIR/ -""" - -import os -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Union - -import draccus -import torch -from peft import PeftModel -from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor - -from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig -from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction -from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor - - -@dataclass -class ConvertConfig: - # fmt: off - - base_checkpoint: Union[str, Path] = "" # Base model checkpoint path/dir (either openvla/openvla-7b or whichever model you fine-tuned / resumed training from) - lora_finetuned_checkpoint_dir: Union[str, Path] = "" # Checkpoint directory containing the LoRA adapter - - # fmt: on - - -@draccus.wrap() -def main(cfg: ConvertConfig) -> None: - # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) - AutoConfig.register("openvla", OpenVLAConfig) - AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) - AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) - AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) - - # Load Model using HF AutoClasses - print(f"Loading base model: {cfg.base_checkpoint}") - vla = AutoModelForVision2Seq.from_pretrained( - cfg.base_checkpoint, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ) - - # Load LoRA weights and merge into base model, then save final checkpoint - print("Merging LoRA weights into base model...") - start_time = time.time() - merged_vla = PeftModel.from_pretrained(vla, os.path.join(cfg.lora_finetuned_checkpoint_dir, "lora_adapter")).to( - "cuda" - ) - merged_vla = merged_vla.merge_and_unload() - merged_vla.save_pretrained(cfg.lora_finetuned_checkpoint_dir) - print(f"\nMerging complete! Time elapsed (sec): {time.time() - start_time}") - print(f"\nSaved merged model checkpoint at:\n{cfg.lora_finetuned_checkpoint_dir}") - - -if __name__ == "__main__": - main() diff --git a/capvector-pi05/.dockerignore b/capvector-pi05/.dockerignore deleted file mode 100644 index d773b6d053e4bc538f1aeea27ac0c63003a40271..0000000000000000000000000000000000000000 --- a/capvector-pi05/.dockerignore +++ /dev/null @@ -1,3 +0,0 @@ -.venv -checkpoints -data diff --git a/capvector-pi05/.gitignore b/capvector-pi05/.gitignore deleted file mode 100644 index fafb7302e27279ffd51d82bf31ef166f11ef4ee9..0000000000000000000000000000000000000000 --- a/capvector-pi05/.gitignore +++ /dev/null @@ -1,169 +0,0 @@ -# Data directories. 
-assets/ -checkpoints/ -data/ -wandb/ - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/latest/usage/project/#working-with-version-control -.pdm.toml -.pdm-python -.pdm-build/ - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
-.idea/ -.vscode/ diff --git a/capvector-pi05/.gitmodules b/capvector-pi05/.gitmodules deleted file mode 100644 index 611e80c319c46648142c4ec7e176ad55d72bde31..0000000000000000000000000000000000000000 --- a/capvector-pi05/.gitmodules +++ /dev/null @@ -1,6 +0,0 @@ -[submodule "third_party/aloha"] - path = third_party/aloha - url = https://github.com/Physical-Intelligence/aloha.git -[submodule "third_party/libero"] - path = third_party/libero - url = https://github.com/Lifelong-Robot-Learning/LIBERO.git diff --git a/capvector-pi05/.pre-commit-config.yaml b/capvector-pi05/.pre-commit-config.yaml deleted file mode 100644 index 28dbf2e1d9fe28d5703f9e991a4ba41e45a5152f..0000000000000000000000000000000000000000 --- a/capvector-pi05/.pre-commit-config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -exclude: third_party/ - -repos: - - repo: https://github.com/astral-sh/uv-pre-commit - # uv version. - rev: 0.5.14 - hooks: - - id: uv-lock - - repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.8.6 - hooks: - # Run the linter. - - id: ruff - args: [--fix] - - id: ruff-format \ No newline at end of file diff --git a/capvector-pi05/.python-version b/capvector-pi05/.python-version deleted file mode 100644 index 902b2c90c86bce733594862f9a5893c7315b6441..0000000000000000000000000000000000000000 --- a/capvector-pi05/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.11 \ No newline at end of file diff --git a/capvector-pi05/LICENSE b/capvector-pi05/LICENSE deleted file mode 100644 index 753842b6720f7980d411ecf2c78eb4ef220b9df8..0000000000000000000000000000000000000000 --- a/capvector-pi05/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/capvector-pi05/README.md b/capvector-pi05/README.md deleted file mode 100644 index 8e7f0ae8e177e7df7808ccadc6054b78c03c23f1..0000000000000000000000000000000000000000 --- a/capvector-pi05/README.md +++ /dev/null @@ -1,128 +0,0 @@ -## 1. Environment Setup -We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment: - -```bash -GIT_LFS_SKIP_SMUDGE=1 uv sync -GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . -cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/ -source .venv/bin/activate -``` - -NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency. - - -## 2. 
Data Preparation
-Here we take real-world ALOHA data as an example; for more details on simulation data, refer to the [official openpi repo](https://github.com/Physical-Intelligence/openpi/).
-
-First, you need to collect the task-specific raw data with your own robot, and save it in the `.hdf5` format.
-
-Then, convert the data to the LeRobot dataset format.
-```bash
-uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<repo_name>
-# By default, the converted data is stored in ~/.cache/huggingface/lerobot/<org>/<repo_name>/
-```
-
-
-## 3. Obtain the capability vectors and merge them to obtain $\theta_{meta}$
-
-First, define your task-specific config in [config.py](src/openpi/training/config.py). We provide an example for our real-world task [here](src/openpi/training/config.py#L776-L808).
-
-Then, convert a JAX model checkpoint to PyTorch format:
-```bash
-uv run examples/convert_jax_model_to_pytorch.py \
-    --checkpoint_dir gs://openpi-assets/checkpoints/pi05_base \
-    --config_name <config_name> \
-    --output_path checkpoints/pytorch_pi05_base
-# This command will automatically download the pi05_base checkpoint to ~/.cache/openpi/openpi-assets/checkpoints/pi05_base/
-# Otherwise, you can download it manually and modify --checkpoint_dir accordingly
-```
-
-> ⭐ If you don't use the regularization strategy, you can download the [capability-merged meta model](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/merged_model) we provide, place it at `./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial/`, and jump directly to the next [Training step](#4-training).
-
-The capability vectors are obtained by simple parameter arithmetic between two models fine-tuned with different strategies, so we first need to prepare these two trained models, *e.g.*, [Pi0.5 on LIBERO-Spatial](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/pi05_baseline_30000step_spatial) and [Pi0.5-SF on LIBERO-Spatial](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/pi05_spatialforcing_30000step_spatial). The directory structure is as follows:
-```
-capvector-pi05
-  ├── checkpoints
-  ·   ├── pi05-LIBEROspatial
-  │   │   ├── model.safetensors
-  │   │   └── ...
-  │   ├── pi05SF-LIBEROspatial
-  │   │   ├── model.safetensors
-  │   │   └── ...
-  │   ├── diff
-  │   ├── vector_init
-  ·
-```
-
-Next, conduct parameter arithmetic between these two models:
-```bash
-CONFIG=pi05_capvector_aloha_place_block && \
-EXT=pi05SF-LIBEROspatial && \
-DOWN=pi05-LIBEROspatial && \
-uv run capvector/compute_param_diff.py \
-    --config $CONFIG \
-    --a.dir checkpoints/$EXT \
-    --b.dir checkpoints/$DOWN \
-    --out checkpoints/diff/${EXT}_minus_${DOWN}.pth \
-    --strict-keys \
-    --dtype fp32
-```
-
-Finally, merge these diff parameters to obtain $\theta_{meta}$:
-```bash
-DIFF=pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial && \
-uv run capvector/apply_param_diff.py \
-    --base-safetensors checkpoints/pytorch_pi05_base/model.safetensors \
-    --diff-pth checkpoints/diff/${DIFF}.pth \
-    --out-safetensors checkpoints/vector_init/${DIFF}/model.safetensors \
-    --scale 1.0 \
-    --no-strict-keys \
-    --dtype fp32 \
-    --device cpu
-```
-
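-In other words, the two commands above perform plain parameter arithmetic. Reading `compute_param_diff.py` and `apply_param_diff.py` together, and writing $\theta_{aux}$ and $\theta_{sft}$ as shorthand for the checkpoints passed as `--a.dir` and `--b.dir`, $\theta_{base}$ for the converted PyTorch pi05_base model, and $\lambda$ for `--scale` (1.0 above), the merged initialization is
-
-$$
-\Delta\theta = \theta_{aux} - \theta_{sft}, \qquad \theta_{meta} = \theta_{base} + \lambda\,\Delta\theta
-$$
-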
-## 4. Training
-First, you need to compute the normalization statistics for the training data.
-```bash
-uv run scripts/compute_norm_stats.py --config-name <config_name>
-```
-
-Then, launch training using one of these modes:
-```bash
-# Single GPU training:
-uv run scripts/train_regular_loss_pytorch.py <config_name> --exp_name <exp_name> --save_interval <save_interval>
-# Example:
-uv run scripts/train_regular_loss_pytorch.py pi05_capvector_aloha_place_block --exp_name pytorch_test
-uv run scripts/train_regular_loss_pytorch.py pi05_capvector_aloha_place_block --exp_name pytorch_test --overwrite  # Overwrite existing checkpoints
-
-# Multi-GPU training (single node):
-uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_regular_loss_pytorch.py <config_name> --exp_name <exp_name>
-
-# Multi-Node Training:
-uv run torchrun \
-  --nnodes=<num_nodes> \
-  --nproc_per_node=<num_gpus_per_node> \
-  --node_rank=<node_rank> \
-  --master_addr=<master_addr> \
-  --master_port=<master_port> \
-  scripts/train_regular_loss_pytorch.py <config_name> --exp_name=<exp_name> --save_interval <save_interval>
-```
-
-
-## 5. Inference
-Real-world inference runs in a server-client setup.
-
-First, launch a model server (we use the checkpoint from iteration 20,000 for this example; modify as needed):
-```bash
-uv run scripts/serve_policy.py policy:checkpoint --policy.config=<config_name> --policy.dir=checkpoints/<config_name>/<exp_name>/20000
-```
-
-This will spin up a server that listens on port 8000 and waits for observations to be sent to it.
-
-We can then run a client robot script that queries the server.
-
-You need to write the client script for your own robot. A simple [client example](examples/simple_client/main.py) can be run as below:
-```bash
-uv run examples/simple_client/main.py --env ALOHA
-```
\ No newline at end of file
diff --git a/capvector-pi05/capvector/apply_param_diff.py b/capvector-pi05/capvector/apply_param_diff.py
deleted file mode 100644
index 2c5db696c642183d0df8b3c05f1282a25c880138..0000000000000000000000000000000000000000
--- a/capvector-pi05/capvector/apply_param_diff.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import dataclasses
-import logging
-from pathlib import Path
-
-import torch
-import tyro
-from safetensors.torch import load_file, save_file
-
-
-@dataclasses.dataclass
-class Args:
-    # Base pretrained weights in safetensors
-    base_safetensors: str
-
-    # Diff checkpoint in .pth (either {"state_dict": ...} or raw state_dict)
-    diff_pth: str
-
-    # Output safetensors path
-    out_safetensors: str = "model_merged.safetensors"
-
-    # final = base + scale * diff
-    scale: float = 1.0
-
-    # whether keys must match exactly
-    strict_keys: bool = True  # use --strict-keys / --no-strict-keys
-
-    # arithmetic dtype
-    dtype: str = "fp32"  # fp32/fp16/bf16
-
-    # compute device
-    device: str = "cpu"  # cpu/cuda
-
-
-def cast(t: torch.Tensor, dtype: str) -> torch.Tensor:
-    if dtype == "fp32":
-        return t.float()
-    if dtype == "fp16":
-        return t.half()
-    if dtype == "bf16":
-        return t.bfloat16()
-    raise ValueError(f"Unknown dtype: {dtype}")
-
-
-def load_diff_state_dict(path: str) -> dict[str, torch.Tensor]:
-    obj = torch.load(path, map_location="cpu")
-    if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
-        sd = obj["state_dict"]
-    elif isinstance(obj, dict):
-        sd = obj
-    else:
-        raise RuntimeError(f"Unexpected diff format: {type(obj)}")
-
-    for k, v in sd.items():
-        if not isinstance(v, torch.Tensor):
-            raise RuntimeError(f"Diff contains non-tensor at key={k}: {type(v)}")
-    return sd
-
-
-def main(args: Args) -> None:
-    logging.info("Loading base safetensors: %s", args.base_safetensors)
-    base_sd = load_file(args.base_safetensors, device="cpu")  # dict[str, Tensor]
-
-    logging.info("Loading diff pth: %s", args.diff_pth)
-    diff_sd = 
load_diff_state_dict(args.diff_pth) - - keys_base = set(base_sd.keys()) - keys_diff = set(diff_sd.keys()) - - if args.strict_keys: - if keys_base != keys_diff: - only_base = sorted(list(keys_base - keys_diff))[:30] - only_diff = sorted(list(keys_diff - keys_base))[:30] - raise RuntimeError( - "Keys mismatch between base safetensors and diff.\n" - f"Only in base (up to 30): {only_base}\n" - f"Only in diff (up to 30): {only_diff}\n" - "Use --no-strict-keys to apply on intersection only." - ) - keys_apply = keys_base - else: - keys_apply = keys_base & keys_diff - logging.warning("Non-strict mode: applying on intersection keys: %d", len(keys_apply)) - - dev = torch.device(args.device) - - merged_sd: dict[str, torch.Tensor] = {} - applied_float = 0 - skipped_nonfloat = 0 - skipped_missing = 0 - - for k, base_t_cpu in base_sd.items(): - base_t = base_t_cpu # already on cpu - - if k not in keys_apply: - merged_sd[k] = base_t - skipped_missing += 1 - continue - - diff_t_cpu = diff_sd[k] - - if base_t.shape != diff_t_cpu.shape: - raise RuntimeError(f"Shape mismatch at key={k}: base {base_t.shape} vs diff {diff_t_cpu.shape}") - - # only add for floating-point tensors - if base_t.is_floating_point() and diff_t_cpu.is_floating_point(): - a = cast(base_t.to(dev), args.dtype) - d = cast(diff_t_cpu.to(dev), args.dtype) - out = a + args.scale * d - merged_sd[k] = out.to(base_t.dtype).detach().cpu() - applied_float += 1 - else: - merged_sd[k] = base_t - skipped_nonfloat += 1 - - out_path = Path(args.out_safetensors) - out_path.parent.mkdir(parents=True, exist_ok=True) - - # safetensors 需要所有 tensor 在 CPU - for k, v in merged_sd.items(): - if v.device.type != "cpu": - merged_sd[k] = v.cpu() - - logging.info( - "Done. applied_float=%d, skipped_nonfloat=%d, skipped_missing=%d", - applied_float, - skipped_nonfloat, - skipped_missing, - ) - logging.info("Saving merged safetensors to: %s", str(out_path)) - save_file(merged_sd, str(out_path)) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, force=True) - main(tyro.cli(Args)) diff --git a/capvector-pi05/capvector/compute_param_diff.py b/capvector-pi05/capvector/compute_param_diff.py deleted file mode 100644 index 6d35337626d3113092581fbdc6b4261b86f6e3a4..0000000000000000000000000000000000000000 --- a/capvector-pi05/capvector/compute_param_diff.py +++ /dev/null @@ -1,142 +0,0 @@ -import dataclasses -import logging -from pathlib import Path -from typing import Any - -import torch -import tyro - -from openpi.training import config as _config - - -@dataclasses.dataclass -class CkptSpec: - dir: str - - -@dataclasses.dataclass -class Args: - config: str - a: CkptSpec - b: CkptSpec - out: str = "checkpoints/diff/a_minus_b.pth" - only_vlm: bool = False - strict_keys: bool = False - dtype: str = "fp32" - device: str = "cpu" - - -def _extract_state_dict(obj: Any) -> dict[str, torch.Tensor]: - """ - Try best to get a torch state_dict from a Policy or Module-like object. 
- """ - # Case 1: policy itself has state_dict() - if hasattr(obj, "state_dict") and callable(obj.state_dict): - sd = obj.state_dict() - if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()): - return sd - - # Case 2: common attributes that hold torch.nn.Module - for attr in ["model", "_model", "module", "net", "_net", "policy", "_policy"]: - if hasattr(obj, attr): - m = getattr(obj, attr) - if hasattr(m, "state_dict") and callable(m.state_dict): - sd = m.state_dict() - if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()): - return sd - - raise RuntimeError( - "Cannot extract state_dict. " - "Please inspect Policy object and update attribute list in _extract_state_dict()." - ) - - -def _cast_tensor(t: torch.Tensor, dtype: str) -> torch.Tensor: - if dtype == "fp32": - return t.float() - if dtype == "fp16": - return t.half() - if dtype == "bf16": - return t.bfloat16() - raise ValueError(f"Unknown dtype: {dtype}") - - -def load_model(config_name: str, spec: CkptSpec): - cfg = _config.get_config(config_name) - weight_path = Path(spec.dir) / "model.safetensors" - if not weight_path.exists(): - raise FileNotFoundError(f"Missing model.safetensors in checkpoint directory: {spec.dir}") - return cfg.model.load_pytorch(cfg, str(weight_path)) - - -def main(args: Args) -> None: - logging.info("Loading A model from %s with config %s", args.a.dir, args.config) - model_a = load_model(args.config, args.a) - logging.info("Loading B model from %s with config %s", args.b.dir, args.config) - model_b = load_model(args.config, args.b) - - sd_a = _extract_state_dict(model_a) - sd_b = _extract_state_dict(model_b) - - keys_a = set(sd_a.keys()) - keys_b = set(sd_b.keys()) - - if args.strict_keys: - if keys_a != keys_b: - only_a = sorted(list(keys_a - keys_b))[:20] - only_b = sorted(list(keys_b - keys_a))[:20] - raise RuntimeError( - f"State dict keys mismatch.\n" - f"Only in A (show up to 20): {only_a}\n" - f"Only in B (show up to 20): {only_b}\n" - f"Set --strict-keys False to subtract intersection only." 
-            )
-        keys = sorted(keys_a)
-    else:
-        keys = sorted(list(keys_a & keys_b))
-        logging.warning("Non-strict mode: subtracting only intersection keys: %d", len(keys))
-
-    device = torch.device(args.device)
-    diff: dict[str, torch.Tensor] = {}
-
-    if args.only_vlm:
-        ZERO_PREFIXES = [
-            "paligemma_with_expert.gemma_expert.",
-            "action_in_proj.",
-            "action_out_proj.",
-            "action_time_mlp_in",
-            "action_time_mlp_out",
-        ]
-    else:
-        ZERO_PREFIXES = []
-
-    for k in keys:
-        ta = sd_a[k].to(device)
-        tb = sd_b[k].to(device)
-
-        if ta.shape != tb.shape:
-            raise RuntimeError(f"Shape mismatch at key={k}: {ta.shape} vs {tb.shape}")
-
-        zero_this = any(k.startswith(p) for p in ZERO_PREFIXES)
-
-        if zero_this:
-            out = torch.zeros_like(ta)
-        else:
-            if ta.is_floating_point():
-                out = _cast_tensor(ta, args.dtype) - _cast_tensor(tb, args.dtype)
-            else:
-                out = ta
-
-        diff[k] = out.detach().cpu()
-
-    out_path = Path(args.out)
-    out_path.parent.mkdir(parents=True, exist_ok=True)
-    torch.save({"state_dict": diff, "a": dataclasses.asdict(args.a), "b": dataclasses.asdict(args.b)}, out_path)
-    logging.info("Saved diff checkpoint to: %s", str(out_path))
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO, force=True)
-    main(tyro.cli(Args))
diff --git a/capvector-pi05/docs/docker.md b/capvector-pi05/docs/docker.md
deleted file mode 100644
index a66ecfa69345ddce52e93bb62628ae6006466ca9..0000000000000000000000000000000000000000
--- a/capvector-pi05/docs/docker.md
+++ /dev/null
@@ -1,25 +0,0 @@
-### Docker Setup
-
-All of the examples in this repo provide instructions for running them normally and also with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
-
-- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
-- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
-- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
-- The version of Docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
-- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
-
-
-If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
-
-Build the Docker image and start the container with the following command:
-```bash
-docker compose -f scripts/docker/compose.yml up --build
-```
-
-To build and run the Docker image for a specific example, use the following command:
-```bash
-docker compose -f examples/<example_name>/compose.yml up --build
-```
-where `<example_name>` is the name of the example you want to run.
-
-During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
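-
-As a quick sanity check that containers can actually see the GPU (the CUDA image tag below is only an example; substitute any tag available to you), you can run:
-```bash
-# Lists the host GPUs from inside a container when the NVIDIA container toolkit is configured correctly.
-docker run --rm --gpus all nvidia/cuda:12.4.1-base-ubuntu22.04 nvidia-smi
-```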
\ No newline at end of file diff --git a/capvector-pi05/docs/norm_stats.md b/capvector-pi05/docs/norm_stats.md deleted file mode 100644 index c9e8811b4bc804d372f1f31ab8774b0c2ad81c1e..0000000000000000000000000000000000000000 --- a/capvector-pi05/docs/norm_stats.md +++ /dev/null @@ -1,69 +0,0 @@ -# Normalization statistics - -Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint. - -## Reloading normalization statistics - -When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model. - -**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint: - -```python -TrainConfig( - ... - data=LeRobotAlohaDataConfig( - ... - assets=AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", - asset_id="trossen", - ), - ), -) -``` - -For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py). - -**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below. - -**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task. - - -## Provided Pre-training Normalization Statistics - -Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`. 
-| Robot | Description | Asset ID | -|-------|-------------|----------| -| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen | -| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile | -| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid | -| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka | -| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e | -| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual | -| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx | -| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile | -| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile | - - -## Pi0 Model Action Space Definitions - -Out of the box, both the `pi0_base` and `pi0_fast_base` models use the following action space definitions (left and right are defined looking from behind the robot towards the workspace): -``` - "dim_0:dim_5": "left arm joint angles", - "dim_6": "left arm gripper position", - "dim_7:dim_12": "right arm joint angles (for bi-manual only)", - "dim_13": "right arm gripper position (for bi-manual only)", - - # For mobile robots: - "dim_14:dim_15": "x-y base velocity (for mobile robots only)", -``` - -The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state. - -For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action. - -General info for Pi robots: -- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details). -- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed. -- Control frequencies are 20 Hz for UR5e and Franka arms, and 50 Hz for ARX and Trossen (ALOHA) arms. - -For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions, gripper actions in the 8th dimension, and a control frequency of 15 Hz. diff --git a/capvector-pi05/docs/remote_inference.md b/capvector-pi05/docs/remote_inference.md deleted file mode 100644 index ffe45f6a2b60ac3b98950b922730ac6de89bdfe5..0000000000000000000000000000000000000000 --- a/capvector-pi05/docs/remote_inference.md +++ /dev/null @@ -1,71 +0,0 @@ - -# Running openpi models remotely - -We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software). - -## Starting a remote policy server - -To start a remote policy server, you can simply run the following command: - -```bash -uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO] -``` - -The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g.
for checkpoints you trained yourself (here an example for the DROID environment): - -```bash -uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid -``` - -This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000). - -## Querying the remote policy server from your robot code - -We provide a client utility with minimal dependencies that you can easily embed into any robot codebase. - -First, install the `openpi-client` package in your robot environment: - -```bash -cd $OPENPI_ROOT/packages/openpi-client -pip install -e . -``` - -Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this: - -```python -from openpi_client import image_tools -from openpi_client import websocket_client_policy - -# Outside of episode loop, initialize the policy client. -# Point to the host and port of the policy server (localhost and 8000 are the defaults). -client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000) - -for step in range(num_steps): - # Inside the episode loop, construct the observation. - # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format. - # We provide utilities for resizing images + uint8 conversion so you match the training routines. - # The typical resize_size for pre-trained pi0 models is 224. - # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side. - observation = { - "observation/image": image_tools.convert_to_uint8( - image_tools.resize_with_pad(img, 224, 224) - ), - "observation/wrist_image": image_tools.convert_to_uint8( - image_tools.resize_with_pad(wrist_img, 224, 224) - ), - "observation/state": state, - "prompt": task_instruction, - } - - # Call the policy server with the current observation. - # This returns an action chunk of shape (action_horizon, action_dim). - # Note that you typically only need to call the policy every N steps and execute steps - # from the predicted action chunk open-loop in the remaining steps. - action_chunk = client.infer(observation)["actions"] - - # Execute the actions in the environment. - ... - -``` - -Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py). diff --git a/capvector-pi05/examples/aloha_real/Dockerfile b/capvector-pi05/examples/aloha_real/Dockerfile deleted file mode 100644 index 488655c638b15f7fcca27ac9ade4d91a4a8254f4..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/Dockerfile +++ /dev/null @@ -1,70 +0,0 @@ -# Dockerfile for the Aloha real environment. - -# Build the container: -# docker build . 
-t aloha_real -f examples/aloha_real/Dockerfile - -# Run the container: -# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash - -FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc -SHELL ["/bin/bash", "-c"] - -ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - cmake \ - curl \ - libffi-dev \ - python3-rosdep \ - python3-rosinstall \ - python3-rosinstall-generator \ - whiptail \ - git \ - wget \ - openssh-client \ - ros-noetic-cv-bridge \ - ros-noetic-usb-cam \ - ros-noetic-realsense2-camera \ - keyboard-configuration - -WORKDIR /root -RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh -RUN chmod +x xsarm_amd64_install.sh -RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n - -COPY ./third_party/aloha /root/interbotix_ws/src/aloha -RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make - -# Install python 3.10 because this ROS image comes with 3.8 -RUN mkdir /python && \ - cd /python && \ - wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \ - tar -zxvf Python-3.10.14.tgz && \ - cd Python-3.10.14 && \ - ls -lhR && \ - ./configure --enable-optimizations && \ - make install && \ - echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \ - echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \ - cd ~ && rm -rf /python && \ - rm -rf /var/lib/apt/lists/* - -COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv -ENV UV_HTTP_TIMEOUT=120 -ENV UV_LINK_MODE=copy -COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt -COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml -RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml - -ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha -WORKDIR /app - -# Create an entrypoint script to run the setup commands, followed by the command passed in. -RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh -#!/bin/bash -source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@" -EOF -RUN chmod +x /usr/local/bin/entrypoint.sh - -ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] -CMD ["python3", "/app/examples/aloha_real/main.py"] diff --git a/capvector-pi05/examples/aloha_real/README.md b/capvector-pi05/examples/aloha_real/README.md deleted file mode 100644 index aadb913b21a03ad4ffde93a2a1c258e51034b82d..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/README.md +++ /dev/null @@ -1,126 +0,0 @@ -# Run Aloha (Real Robot) - -This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below. - -## Prerequisites - -This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras. - -1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo. -1. 
Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras. - -## With Docker - -```bash -export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'" -docker compose -f examples/aloha_real/compose.yml up --build -``` - -## Without Docker - -Terminal window 1: - -```bash -# Create virtual environment -uv venv --python 3.10 examples/aloha_real/.venv -source examples/aloha_real/.venv/bin/activate -uv pip sync examples/aloha_real/requirements.txt -uv pip install -e packages/openpi-client - -# Run the robot -python -m examples.aloha_real.main -``` - -Terminal window 2: - -```bash -roslaunch aloha ros_nodes.launch -``` - -Terminal window 3: - -```bash -uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster' -``` - -## **ALOHA Checkpoint Guide** - - -The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA. - -While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot. - - ---- - -### **Toast Task** - -This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate. - -- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base` -- **Prompt**: "take the toast out of the toaster" -- **Objects needed**: Two pieces of toast, a plate, and a standard toaster. -- **Object Distribution**: - - Works on both real toast and rubber fake toast - - Compatible with standard 2-slice toasters - - Works with plates of varying colors - -### **Scene Setup Guidelines** -Screenshot 2025-01-31 at 10 06 02 PM - -- The toaster should be positioned in the top-left quadrant of the workspace. -- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top. -- The plate should be placed roughly in the lower-center of the workspace. -- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain). - - -### **Towel Task** - -This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths. - -- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel` -- **Prompt**: "fold the towel" -- **Object Distribution**: - - Works on towels of varying solid colors - - Performance is worse on heavily textured or striped towels - -### **Scene Setup Guidelines** -Screenshot 2025-01-31 at 10 01 15 PM - -- The towel should be flattened and roughly centered on the table. -- Choose a towel that does not blend in with the table surface. - - -### **Tupperware Task** - -This task involves opening a tupperware filled with food and pouring the contents onto a plate. 
- -- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` -- **Prompt**: "open the tupperware and put the food on the plate" -- **Objects needed**: Tupperware, food (or food-like items), and a plate. -- **Object Distribution**: - - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken). - - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below). - - The policy has seen plates of varying solid colors. - -### **Scene Setup Guidelines** -Screenshot 2025-01-31 at 10 02 27 PM - -- Best performance observed when both the tupperware and plate are roughly centered in the workspace. -- Positioning: - - Tupperware should be on the left. - - Plate should be on the right or bottom. - - The tupperware flap should point toward the plate. - -## Training on your own Aloha dataset - -1. Convert the dataset to the LeRobot dataset v2.0 format. - - We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse). - - -2. Define a training config that uses the custom dataset. - - We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config. - -IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig. diff --git a/capvector-pi05/examples/aloha_real/compose.yml b/capvector-pi05/examples/aloha_real/compose.yml deleted file mode 100644 index bfb27cfcf9bcb8e4a3e6098abc0d0ffc394fc555..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/compose.yml +++ /dev/null @@ -1,66 +0,0 @@ -# Run with: -# docker compose -f examples/aloha_real/compose.yml up --build -services: - runtime: - image: aloha_real - depends_on: - - aloha_ros_nodes - - ros_master - - openpi_server - build: - context: ../.. - dockerfile: examples/aloha_real/Dockerfile - init: true - tty: true - network_mode: host - privileged: true - volumes: - - $PWD:/app - - ../../data:/data - - aloha_ros_nodes: - image: aloha_real - depends_on: - - ros_master - build: - context: ../.. - dockerfile: examples/aloha_real/Dockerfile - init: true - tty: true - network_mode: host - privileged: true - volumes: - - /dev:/dev - command: roslaunch --wait aloha ros_nodes.launch - - ros_master: - image: ros:noetic-robot - network_mode: host - privileged: true - command: - - roscore - - openpi_server: - image: openpi_server - build: - context: ../.. 
- dockerfile: scripts/docker/serve_policy.Dockerfile - init: true - tty: true - network_mode: host - volumes: - - $PWD:/app - - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets - environment: - - SERVER_ARGS - - OPENPI_DATA_HOME=/openpi_assets - - IS_DOCKER=true - - # Comment out this block if not running on a machine with GPUs. - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] diff --git a/capvector-pi05/examples/aloha_real/constants.py b/capvector-pi05/examples/aloha_real/constants.py deleted file mode 100644 index 59abfb84390757b53553572d06470ae12fcf8d4d..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/constants.py +++ /dev/null @@ -1,71 +0,0 @@ -# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). -# ruff: noqa - -### Task parameters - -### ALOHA fixed constants -DT = 0.001 -JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] -START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] - -# Left finger position limits (qpos[7]), right_finger = -1 * left_finger -MASTER_GRIPPER_POSITION_OPEN = 0.02417 -MASTER_GRIPPER_POSITION_CLOSE = 0.01244 -PUPPET_GRIPPER_POSITION_OPEN = 0.05800 -PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 - -# Gripper joint limits (qpos[6]) -MASTER_GRIPPER_JOINT_OPEN = 0.3083 -MASTER_GRIPPER_JOINT_CLOSE = -0.6842 -PUPPET_GRIPPER_JOINT_OPEN = 1.4910 -PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 - -############################ Helper functions ############################ - -MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / ( - MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE -) -PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / ( - PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE -) -MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = ( - lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE -) -PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = ( - lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE -) -MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) - -MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / ( - MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE -) -PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / ( - PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE -) -MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = ( - lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE -) -PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = ( - lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE -) -MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) - -MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) -PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) - -MASTER_POS2JOINT = ( - lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) - + MASTER_GRIPPER_JOINT_CLOSE -) -MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN( - (x 
- MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) -) -PUPPET_POS2JOINT = ( - lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) - + PUPPET_GRIPPER_JOINT_CLOSE -) -PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN( - (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) -) - -MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2 diff --git a/capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py b/capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py deleted file mode 100644 index 8918663120f29b478675ef69c3c4e0e71ccca471..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py +++ /dev/null @@ -1,263 +0,0 @@ -""" -Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format. - -Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id / -""" - -import dataclasses -from pathlib import Path -import shutil -from typing import Literal - -import h5py -from lerobot.common.constants import HF_LEROBOT_HOME -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -import numpy as np -import torch -import tqdm -import tyro - - -@dataclasses.dataclass(frozen=True) -class DatasetConfig: - use_videos: bool = True - tolerance_s: float = 0.0001 - image_writer_processes: int = 10 - image_writer_threads: int = 5 - video_backend: str | None = None - - -DEFAULT_DATASET_CONFIG = DatasetConfig() - - -def create_empty_dataset( - repo_id: str, - robot_type: str, - cameras: list[str], - mode: Literal["video", "image"] = "video", - *, - has_velocity: bool = False, - has_effort: bool = False, - dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG, -) -> LeRobotDataset: - motors = [ - "right_waist", - "right_shoulder", - "right_elbow", - "right_forearm_roll", - "right_wrist_angle", - "right_wrist_rotate", - "right_gripper", - "left_waist", - "left_shoulder", - "left_elbow", - "left_forearm_roll", - "left_wrist_angle", - "left_wrist_rotate", - "left_gripper", - ] - - features = { - "observation.state": { - "dtype": "float32", - "shape": (len(motors),), - "names": [ - motors, - ], - }, - "action": { - "dtype": "float32", - "shape": (len(motors),), - "names": [ - motors, - ], - }, - } - - if has_velocity: - features["observation.velocity"] = { - "dtype": "float32", - "shape": (len(motors),), - "names": [ - motors, - ], - } - - if has_effort: - features["observation.effort"] = { - "dtype": "float32", - "shape": (len(motors),), - "names": [ - motors, - ], - } - - for cam in cameras: - features[f"observation.images.{cam}"] = { - "dtype": mode, - "shape": (3, 480, 640), - "names": [ - "channels", - "height", - "width", - ], - } - - if Path(HF_LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(HF_LEROBOT_HOME / repo_id) - - return LeRobotDataset.create( - repo_id=repo_id, - fps=50, - robot_type=robot_type, - features=features, - use_videos=dataset_config.use_videos, - tolerance_s=dataset_config.tolerance_s, - image_writer_processes=dataset_config.image_writer_processes, - image_writer_threads=dataset_config.image_writer_threads, - video_backend=dataset_config.video_backend, - ) - - -def get_cameras(hdf5_files: list[Path]) -> list[str]: - with h5py.File(hdf5_files[0], "r") as ep: - # ignore depth channel, not currently handled - return [key for key in ep["/observations/images"].keys() 
if "depth" not in key] # noqa: SIM118 - - -def has_velocity(hdf5_files: list[Path]) -> bool: - with h5py.File(hdf5_files[0], "r") as ep: - return "/observations/qvel" in ep - - -def has_effort(hdf5_files: list[Path]) -> bool: - with h5py.File(hdf5_files[0], "r") as ep: - return "/observations/effort" in ep - - -def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]: - imgs_per_cam = {} - for camera in cameras: - uncompressed = ep[f"/observations/images/{camera}"].ndim == 4 - - if uncompressed: - # load all images in RAM - imgs_array = ep[f"/observations/images/{camera}"][:] - else: - import cv2 - - # load one compressed image after the other in RAM and uncompress - imgs_array = [] - for data in ep[f"/observations/images/{camera}"]: - imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB)) - imgs_array = np.array(imgs_array) - - imgs_per_cam[camera] = imgs_array - return imgs_per_cam - - -def load_raw_episode_data( - ep_path: Path, - cameras: list[str], -) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - with h5py.File(ep_path, "r") as ep: - state = torch.from_numpy(ep["/observations/qpos"][:]) - action = torch.from_numpy(ep["/action"][:]) - - velocity = None - if "/observations/qvel" in ep: - velocity = torch.from_numpy(ep["/observations/qvel"][:]) - - effort = None - if "/observations/effort" in ep: - effort = torch.from_numpy(ep["/observations/effort"][:]) - - imgs_per_cam = load_raw_images_per_camera(ep, cameras) - - return imgs_per_cam, state, action, velocity, effort - - -def populate_dataset( - dataset: LeRobotDataset, - hdf5_files: list[Path], - cameras: list[str], - task: str, - episodes: list[int] | None = None, -) -> LeRobotDataset: - if episodes is None: - episodes = range(len(hdf5_files)) - - for ep_idx in tqdm.tqdm(episodes): - ep_path = hdf5_files[ep_idx] - - imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path, cameras) - num_frames = state.shape[0] - - for i in range(num_frames): - frame = { - "observation.state": state[i], - "action": action[i], - "task": task, - } - - for camera, img_array in imgs_per_cam.items(): - frame[f"observation.images.{camera}"] = img_array[i] - - if velocity is not None: - frame["observation.velocity"] = velocity[i] - if effort is not None: - frame["observation.effort"] = effort[i] - - dataset.add_frame(frame) - - dataset.save_episode() - - return dataset - - -def port_aloha( - raw_dir: Path, - repo_id: str, - task: str = "DEBUG", - *, - episodes: list[int] | None = None, - push_to_hub: bool = False, - is_mobile: bool = False, - mode: Literal["video", "image"] = "image", - dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG, -): - if (HF_LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(HF_LEROBOT_HOME / repo_id) - - if not raw_dir.exists(): - raise ValueError(f"Raw directory {raw_dir} does not exist. 
Please provide a valid path to the raw data.") - - hdf5_files = sorted(raw_dir.glob("episode_*.hdf5")) - - # Get camera names from the first episode - cameras = get_cameras(hdf5_files) - print(f"Detected cameras: {cameras}") - - dataset = create_empty_dataset( - repo_id, - robot_type="mobile_aloha" if is_mobile else "aloha", - cameras=cameras, - mode=mode, - has_effort=has_effort(hdf5_files), - has_velocity=has_velocity(hdf5_files), - dataset_config=dataset_config, - ) - dataset = populate_dataset( - dataset, - hdf5_files, - cameras=cameras, - task=task, - episodes=episodes, - ) - - if push_to_hub: - dataset.push_to_hub() - - -if __name__ == "__main__": - tyro.cli(port_aloha) diff --git a/capvector-pi05/examples/aloha_real/env.py b/capvector-pi05/examples/aloha_real/env.py deleted file mode 100644 index e583c71bcbacf18c927fb4de262ddd2df66e6d47..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/env.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import List, Optional # noqa: UP035 - -import einops -from openpi_client import image_tools -from openpi_client.runtime import environment as _environment -from typing_extensions import override - -from examples.aloha_real import real_env as _real_env - - -class AlohaRealEnvironment(_environment.Environment): - """An environment for an Aloha robot on real hardware.""" - - def __init__( - self, - reset_position: Optional[List[float]] = None, # noqa: UP006,UP007 - render_height: int = 224, - render_width: int = 224, - ) -> None: - self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position) - self._render_height = render_height - self._render_width = render_width - - self._ts = None - - @override - def reset(self) -> None: - self._ts = self._env.reset() - - @override - def is_episode_complete(self) -> bool: - return False - - @override - def get_observation(self) -> dict: - if self._ts is None: - raise RuntimeError("Timestep is not set. 
Call reset() first.") - - obs = self._ts.observation - for k in list(obs["images"].keys()): - if "_depth" in k: - del obs["images"][k] - - for cam_name in obs["images"]: - img = image_tools.convert_to_uint8( - image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width) - ) - obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w") - - return { - "state": obs["qpos"], - "images": obs["images"], - } - - @override - def apply_action(self, action: dict) -> None: - self._ts = self._env.step(action["actions"]) diff --git a/capvector-pi05/examples/aloha_real/main.py b/capvector-pi05/examples/aloha_real/main.py deleted file mode 100644 index 1ceab3d65f9597125d00114305b821eaf16732be..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/main.py +++ /dev/null @@ -1,51 +0,0 @@ -import dataclasses -import logging - -from openpi_client import action_chunk_broker -from openpi_client import websocket_client_policy as _websocket_client_policy -from openpi_client.runtime import runtime as _runtime -from openpi_client.runtime.agents import policy_agent as _policy_agent -import tyro - -from examples.aloha_real import env as _env - - -@dataclasses.dataclass -class Args: - host: str = "0.0.0.0" - port: int = 8000 - - action_horizon: int = 25 - - num_episodes: int = 1 - max_episode_steps: int = 1000 - - -def main(args: Args) -> None: - ws_client_policy = _websocket_client_policy.WebsocketClientPolicy( - host=args.host, - port=args.port, - ) - logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}") - - metadata = ws_client_policy.get_server_metadata() - runtime = _runtime.Runtime( - environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")), - agent=_policy_agent.PolicyAgent( - policy=action_chunk_broker.ActionChunkBroker( - policy=ws_client_policy, - action_horizon=args.action_horizon, - ) - ), - subscribers=[], - max_hz=50, - num_episodes=args.num_episodes, - max_episode_steps=args.max_episode_steps, - ) - - runtime.run() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, force=True) - tyro.cli(main) diff --git a/capvector-pi05/examples/aloha_real/real_env.py b/capvector-pi05/examples/aloha_real/real_env.py deleted file mode 100644 index 04e1c9c3e0a8f4af81609aaae46cb105865c952d..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/real_env.py +++ /dev/null @@ -1,176 +0,0 @@ -# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). -# ruff: noqa -import collections -import time -from typing import Optional, List -import dm_env -from interbotix_xs_modules.arm import InterbotixManipulatorXS -from interbotix_xs_msgs.msg import JointSingleCommand -import numpy as np - -from examples.aloha_real import constants -from examples.aloha_real import robot_utils - -# This is the reset position that is used by the standard Aloha runtime. 
-DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0] - - -class RealEnv: - """ - Environment for real robot bi-manual manipulation - Action space: [left_arm_qpos (6), # absolute joint position - left_gripper_positions (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) - - Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position - left_gripper_position (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) - "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) - left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) - right_arm_qvel (6), # absolute joint velocity (rad) - right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) - "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8' - "cam_low": (480x640x3), # h, w, c, dtype='uint8' - "cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8' - "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8' - """ - - def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True): - # reset_position = START_ARM_POSE[:6] - self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION - - self.puppet_bot_left = InterbotixManipulatorXS( - robot_model="vx300s", - group_name="arm", - gripper_name="gripper", - robot_name="puppet_left", - init_node=init_node, - ) - self.puppet_bot_right = InterbotixManipulatorXS( - robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False - ) - if setup_robots: - self.setup_robots() - - self.recorder_left = robot_utils.Recorder("left", init_node=False) - self.recorder_right = robot_utils.Recorder("right", init_node=False) - self.image_recorder = robot_utils.ImageRecorder(init_node=False) - self.gripper_command = JointSingleCommand(name="gripper") - - def setup_robots(self): - robot_utils.setup_puppet_bot(self.puppet_bot_left) - robot_utils.setup_puppet_bot(self.puppet_bot_right) - - def get_qpos(self): - left_qpos_raw = self.recorder_left.qpos - right_qpos_raw = self.recorder_right.qpos - left_arm_qpos = left_qpos_raw[:6] - right_arm_qpos = right_qpos_raw[:6] - left_gripper_qpos = [ - constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7]) - ] # this is position not joint - right_gripper_qpos = [ - constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7]) - ] # this is position not joint - return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) - - def get_qvel(self): - left_qvel_raw = self.recorder_left.qvel - right_qvel_raw = self.recorder_right.qvel - left_arm_qvel = left_qvel_raw[:6] - right_arm_qvel = right_qvel_raw[:6] - left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])] - right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])] - return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) - - def get_effort(self): - left_effort_raw = self.recorder_left.effort - right_effort_raw = self.recorder_right.effort - left_robot_effort = left_effort_raw[:7] - right_robot_effort = right_effort_raw[:7] - return np.concatenate([left_robot_effort, right_robot_effort]) - - def get_images(self): - return 
self.image_recorder.get_images() - - def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized): - left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized) - self.gripper_command.cmd = left_gripper_desired_joint - self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command) - - right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN( - right_gripper_desired_pos_normalized - ) - self.gripper_command.cmd = right_gripper_desired_joint - self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command) - - def _reset_joints(self): - robot_utils.move_arms( - [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1 - ) - - def _reset_gripper(self): - """Set to position mode and do position resets: first close then open. Then change back to PWM mode - - NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data - was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to - increase the frequency of motor faults. - """ - robot_utils.move_grippers( - [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1 - ) - robot_utils.move_grippers( - [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5 - ) - - def get_observation(self): - obs = collections.OrderedDict() - obs["qpos"] = self.get_qpos() - obs["qvel"] = self.get_qvel() - obs["effort"] = self.get_effort() - obs["images"] = self.get_images() - return obs - - def get_reward(self): - return 0 - - def reset(self, *, fake=False): - if not fake: - # Reboot puppet robot gripper motors - self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True) - self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True) - self._reset_joints() - self._reset_gripper() - return dm_env.TimeStep( - step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation() - ) - - def step(self, action): - state_len = int(len(action) / 2) - left_action = action[:state_len] - right_action = action[state_len:] - self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False) - self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False) - self.set_gripper_pose(left_action[-1], right_action[-1]) - time.sleep(constants.DT) - return dm_env.TimeStep( - step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation() - ) - - -def get_action(master_bot_left, master_bot_right): - action = np.zeros(14) # 6 joint + 1 gripper, for two arms - # Arm actions - action[:6] = master_bot_left.dxl.joint_states.position[:6] - action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6] - # Gripper actions - action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6]) - action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6]) - - return action - - -def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv: - return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots) diff --git a/capvector-pi05/examples/aloha_real/requirements.in b/capvector-pi05/examples/aloha_real/requirements.in deleted file 
mode 100644 index 1763156a3a698aa6d78d6f35ef93b3e5f0b7e177..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/requirements.in +++ /dev/null @@ -1,18 +0,0 @@ -Pillow -dm_control -einops -h5py -matplotlib -modern_robotics -msgpack -numpy>=1.22.4,<2.0.0 -opencv-python -packaging -pexpect -pyquaternion -pyrealsense2 -pyyaml -requests -rospkg -tyro -websockets diff --git a/capvector-pi05/examples/aloha_real/requirements.txt b/capvector-pi05/examples/aloha_real/requirements.txt deleted file mode 100644 index b2c3d38cdbb268f62e40e7ba8820ef0b883a2d11..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/requirements.txt +++ /dev/null @@ -1,156 +0,0 @@ -# This file was autogenerated by uv via the following command: -# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10 -absl-py==2.1.0 - # via - # dm-control - # dm-env - # labmaze - # mujoco -catkin-pkg==1.0.0 - # via rospkg -certifi==2024.8.30 - # via requests -charset-normalizer==3.4.0 - # via requests -contourpy==1.1.1 - # via matplotlib -cycler==0.12.1 - # via matplotlib -distro==1.9.0 - # via rospkg -dm-control==1.0.23 - # via -r examples/aloha_real/requirements.in -dm-env==1.6 - # via dm-control -dm-tree==0.1.8 - # via - # dm-control - # dm-env -docstring-parser==0.16 - # via tyro -docutils==0.20.1 - # via catkin-pkg -einops==0.8.0 - # via -r examples/aloha_real/requirements.in -etils==1.3.0 - # via mujoco -fonttools==4.55.2 - # via matplotlib -glfw==2.8.0 - # via - # dm-control - # mujoco -h5py==3.11.0 - # via -r examples/aloha_real/requirements.in -idna==3.10 - # via requests -importlib-resources==6.4.5 - # via etils -kiwisolver==1.4.7 - # via matplotlib -labmaze==1.0.6 - # via dm-control -lxml==5.3.0 - # via dm-control -markdown-it-py==3.0.0 - # via rich -matplotlib==3.7.5 - # via -r examples/aloha_real/requirements.in -mdurl==0.1.2 - # via markdown-it-py -modern-robotics==1.1.1 - # via -r examples/aloha_real/requirements.in -msgpack==1.1.0 - # via -r examples/aloha_real/requirements.in -mujoco==3.2.3 - # via dm-control -numpy==1.24.4 - # via - # -r examples/aloha_real/requirements.in - # contourpy - # dm-control - # dm-env - # h5py - # labmaze - # matplotlib - # modern-robotics - # mujoco - # opencv-python - # pyquaternion - # scipy -opencv-python==4.10.0.84 - # via -r examples/aloha_real/requirements.in -packaging==24.2 - # via - # -r examples/aloha_real/requirements.in - # matplotlib -pexpect==4.9.0 - # via -r examples/aloha_real/requirements.in -pillow==10.4.0 - # via - # -r examples/aloha_real/requirements.in - # matplotlib -protobuf==5.29.1 - # via dm-control -ptyprocess==0.7.0 - # via pexpect -pygments==2.18.0 - # via rich -pyopengl==3.1.7 - # via - # dm-control - # mujoco -pyparsing==3.1.4 - # via - # catkin-pkg - # dm-control - # matplotlib -pyquaternion==0.9.9 - # via -r examples/aloha_real/requirements.in -pyrealsense2==2.55.1.6486 - # via -r examples/aloha_real/requirements.in -python-dateutil==2.9.0.post0 - # via - # catkin-pkg - # matplotlib -pyyaml==6.0.2 - # via - # -r examples/aloha_real/requirements.in - # rospkg -requests==2.32.3 - # via - # -r examples/aloha_real/requirements.in - # dm-control -rich==13.9.4 - # via tyro -rospkg==1.5.1 - # via -r examples/aloha_real/requirements.in -scipy==1.10.1 - # via dm-control -setuptools==75.3.0 - # via - # catkin-pkg - # dm-control - # labmaze -shtab==1.7.1 - # via tyro -six==1.17.0 - # via python-dateutil -tqdm==4.67.1 - # via dm-control -typeguard==4.4.0 - # 
via tyro -typing-extensions==4.12.2 - # via - # etils - # rich - # typeguard - # tyro -tyro==0.9.2 - # via -r examples/aloha_real/requirements.in -urllib3==2.2.3 - # via requests -websockets==14.1 - # via -r examples/aloha_real/requirements.in -zipp==3.20.2 - # via etils diff --git a/capvector-pi05/examples/aloha_real/robot_utils.py b/capvector-pi05/examples/aloha_real/robot_utils.py deleted file mode 100644 index 62e6d37b57b4e253876bd60d9e66bd620800d2cc..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/robot_utils.py +++ /dev/null @@ -1,275 +0,0 @@ -# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). -# ruff: noqa -from collections import deque -import datetime -import json -import time - -from aloha.msg import RGBGrayscaleImage -from cv_bridge import CvBridge -from interbotix_xs_msgs.msg import JointGroupCommand -from interbotix_xs_msgs.msg import JointSingleCommand -import numpy as np -import rospy -from sensor_msgs.msg import JointState - -from examples.aloha_real import constants - - -class ImageRecorder: - def __init__(self, init_node=True, is_debug=False): - self.is_debug = is_debug - self.bridge = CvBridge() - self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"] - - if init_node: - rospy.init_node("image_recorder", anonymous=True) - for cam_name in self.camera_names: - setattr(self, f"{cam_name}_rgb_image", None) - setattr(self, f"{cam_name}_depth_image", None) - setattr(self, f"{cam_name}_timestamp", 0.0) - if cam_name == "cam_high": - callback_func = self.image_cb_cam_high - elif cam_name == "cam_low": - callback_func = self.image_cb_cam_low - elif cam_name == "cam_left_wrist": - callback_func = self.image_cb_cam_left_wrist - elif cam_name == "cam_right_wrist": - callback_func = self.image_cb_cam_right_wrist - else: - raise NotImplementedError - rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func) - if self.is_debug: - setattr(self, f"{cam_name}_timestamps", deque(maxlen=50)) - - self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names} - time.sleep(0.5) - - def image_cb(self, cam_name, data): - setattr( - self, - f"{cam_name}_rgb_image", - self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"), - ) - # setattr( - # self, - # f"{cam_name}_depth_image", - # self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"), - # ) - setattr( - self, - f"{cam_name}_timestamp", - data.header.stamp.secs + data.header.stamp.nsecs * 1e-9, - ) - # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs) - # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs) - # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image) - if self.is_debug: - getattr(self, f"{cam_name}_timestamps").append( - data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9 - ) - - def image_cb_cam_high(self, data): - cam_name = "cam_high" - return self.image_cb(cam_name, data) - - def image_cb_cam_low(self, data): - cam_name = "cam_low" - return self.image_cb(cam_name, data) - - def image_cb_cam_left_wrist(self, data): - cam_name = "cam_left_wrist" - return self.image_cb(cam_name, data) - - def image_cb_cam_right_wrist(self, data): - cam_name = "cam_right_wrist" - return self.image_cb(cam_name, data) - - def get_images(self): - image_dict = {} - for cam_name in self.camera_names: - while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]: - time.sleep(0.00001) - rgb_image = getattr(self, 
f"{cam_name}_rgb_image") - depth_image = getattr(self, f"{cam_name}_depth_image") - self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp") - image_dict[cam_name] = rgb_image - image_dict[f"{cam_name}_depth"] = depth_image - return image_dict - - def print_diagnostics(self): - def dt_helper(l): - l = np.array(l) - diff = l[1:] - l[:-1] - return np.mean(diff) - - for cam_name in self.camera_names: - image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps")) - print(f"{cam_name} {image_freq=:.2f}") - print() - - -class Recorder: - def __init__(self, side, init_node=True, is_debug=False): - self.secs = None - self.nsecs = None - self.qpos = None - self.effort = None - self.arm_command = None - self.gripper_command = None - self.is_debug = is_debug - - if init_node: - rospy.init_node("recorder", anonymous=True) - rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb) - rospy.Subscriber( - f"/puppet_{side}/commands/joint_group", - JointGroupCommand, - self.puppet_arm_commands_cb, - ) - rospy.Subscriber( - f"/puppet_{side}/commands/joint_single", - JointSingleCommand, - self.puppet_gripper_commands_cb, - ) - if self.is_debug: - self.joint_timestamps = deque(maxlen=50) - self.arm_command_timestamps = deque(maxlen=50) - self.gripper_command_timestamps = deque(maxlen=50) - time.sleep(0.1) - - def puppet_state_cb(self, data): - self.qpos = data.position - self.qvel = data.velocity - self.effort = data.effort - self.data = data - if self.is_debug: - self.joint_timestamps.append(time.time()) - - def puppet_arm_commands_cb(self, data): - self.arm_command = data.cmd - if self.is_debug: - self.arm_command_timestamps.append(time.time()) - - def puppet_gripper_commands_cb(self, data): - self.gripper_command = data.cmd - if self.is_debug: - self.gripper_command_timestamps.append(time.time()) - - def print_diagnostics(self): - def dt_helper(l): - l = np.array(l) - diff = l[1:] - l[:-1] - return np.mean(diff) - - joint_freq = 1 / dt_helper(self.joint_timestamps) - arm_command_freq = 1 / dt_helper(self.arm_command_timestamps) - gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps) - - print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n") - - -def get_arm_joint_positions(bot): - return bot.arm.core.joint_states.position[:6] - - -def get_arm_gripper_positions(bot): - return bot.gripper.core.joint_states.position[6] - - -def move_arms(bot_list, target_pose_list, move_time=1): - num_steps = int(move_time / constants.DT) - curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list] - traj_list = [ - np.linspace(curr_pose, target_pose, num_steps) - for curr_pose, target_pose in zip(curr_pose_list, target_pose_list) - ] - for t in range(num_steps): - for bot_id, bot in enumerate(bot_list): - bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False) - time.sleep(constants.DT) - - -def move_grippers(bot_list, target_pose_list, move_time): - print(f"Moving grippers to {target_pose_list=}") - gripper_command = JointSingleCommand(name="gripper") - num_steps = int(move_time / constants.DT) - curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list] - traj_list = [ - np.linspace(curr_pose, target_pose, num_steps) - for curr_pose, target_pose in zip(curr_pose_list, target_pose_list) - ] - - with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f: - for t in range(num_steps): - d = {} - for bot_id, bot in enumerate(bot_list): - gripper_command.cmd = 
traj_list[bot_id][t] - bot.gripper.core.pub_single.publish(gripper_command) - d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]} - f.write(json.dumps(d) + "\n") - time.sleep(constants.DT) - - -def setup_puppet_bot(bot): - bot.dxl.robot_reboot_motors("single", "gripper", True) - bot.dxl.robot_set_operating_modes("group", "arm", "position") - bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") - torque_on(bot) - - -def setup_master_bot(bot): - bot.dxl.robot_set_operating_modes("group", "arm", "pwm") - bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") - torque_off(bot) - - -def set_standard_pid_gains(bot): - bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800) - bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0) - - -def set_low_pid_gains(bot): - bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100) - bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0) - - -def torque_off(bot): - bot.dxl.robot_torque_enable("group", "arm", False) - bot.dxl.robot_torque_enable("single", "gripper", False) - - -def torque_on(bot): - bot.dxl.robot_torque_enable("group", "arm", True) - bot.dxl.robot_torque_enable("single", "gripper", True) - - -# for DAgger -def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right): - print("\nSyncing!") - - # activate master arms - torque_on(master_bot_left) - torque_on(master_bot_right) - - # get puppet arm positions - puppet_left_qpos = get_arm_joint_positions(puppet_bot_left) - puppet_right_qpos = get_arm_joint_positions(puppet_bot_right) - - # get puppet gripper positions - puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left) - puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right) - - # move master arms to puppet positions - move_arms( - [master_bot_left, master_bot_right], - [puppet_left_qpos, puppet_right_qpos], - move_time=1, - ) - - # move master grippers to puppet positions - move_grippers( - [master_bot_left, master_bot_right], - [puppet_left_gripper, puppet_right_gripper], - move_time=1, - ) diff --git a/capvector-pi05/examples/aloha_real/video_display.py b/capvector-pi05/examples/aloha_real/video_display.py deleted file mode 100644 index 97d9564d109bdc50196f3bc605859f07386d304a..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_real/video_display.py +++ /dev/null @@ -1,36 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -from openpi_client.runtime import subscriber as _subscriber -from typing_extensions import override - - -class VideoDisplay(_subscriber.Subscriber): - """Displays video frames.""" - - def __init__(self) -> None: - self._ax: plt.Axes | None = None - self._plt_img: plt.Image | None = None - - @override - def on_episode_start(self) -> None: - plt.ion() - self._ax = plt.subplot() - self._plt_img = None - - @override - def on_step(self, observation: dict, action: dict) -> None: - assert self._ax is not None - - im = observation["image"][0] # [C, H, W] - im = np.transpose(im, (1, 2, 0)) # [H, W, C] - - if self._plt_img is None: - self._plt_img = self._ax.imshow(im) - else: - self._plt_img.set_data(im) - plt.pause(0.001) - - @override - def on_episode_end(self) -> None: - plt.ioff() - plt.close() diff --git a/capvector-pi05/examples/aloha_sim/Dockerfile b/capvector-pi05/examples/aloha_sim/Dockerfile deleted file mode 100644 index 
d38fa0e853a15f45ac5aa7ce992fd39751131354..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/Dockerfile +++ /dev/null @@ -1,41 +0,0 @@ -# Dockerfile for the Aloha simulation environment. - -# Build the container: -# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile - -# Run the container: -# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash - -FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78 -COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ - -RUN apt-get update && \ - apt-get install -y \ - libosmesa6-dev \ - libgl1-mesa-glx \ - libglew-dev \ - libglfw3-dev \ - libgles2-mesa-dev -ENV MUJOCO_GL=egl - -WORKDIR /app - -# Copy from the cache instead of linking since it's a mounted volume -ENV UV_LINK_MODE=copy - -# Write the virtual environment outside of the project directory so it doesn't -# leak out of the container when we mount the application code. -ENV UV_PROJECT_ENVIRONMENT=/.venv - -# Copy the requirements files so we can install dependencies. -# The rest of the project is mounted as a volume, so we don't need to rebuild on changes. -# This strategy is best for development-style usage. -COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt -COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml - -# Install python dependencies. -RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT -RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml -ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src - -CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"] \ No newline at end of file diff --git a/capvector-pi05/examples/aloha_sim/README.md b/capvector-pi05/examples/aloha_sim/README.md deleted file mode 100644 index 2430e2bea8cbf27bda65bebae02adfff0702cbb7..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# Run Aloha Sim - -## With Docker - -```bash -export SERVER_ARGS="--env ALOHA_SIM" -docker compose -f examples/aloha_sim/compose.yml up --build -``` - -## Without Docker - -Terminal window 1: - -```bash -# Create virtual environment -uv venv --python 3.10 examples/aloha_sim/.venv -source examples/aloha_sim/.venv/bin/activate -uv pip sync examples/aloha_sim/requirements.txt -uv pip install -e packages/openpi-client - -# Run the simulation -MUJOCO_GL=egl python examples/aloha_sim/main.py -``` - -Note: If you are seeing EGL errors, you may need to install the following dependencies: - -```bash -sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev -``` - -Terminal window 2: - -```bash -# Run the server -uv run scripts/serve_policy.py --env ALOHA_SIM -``` diff --git a/capvector-pi05/examples/aloha_sim/compose.yml b/capvector-pi05/examples/aloha_sim/compose.yml deleted file mode 100644 index 5e13b66682fe22d9747d1343b48b5f541aa8f24a..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/compose.yml +++ /dev/null @@ -1,42 +0,0 @@ -# Run with: -# docker compose -f examples/aloha_sim/compose.yml up --build -services: - runtime: - image: aloha_sim - depends_on: - - openpi_server - build: - context: ../.. - dockerfile: examples/aloha_sim/Dockerfile - init: true - tty: true - network_mode: host - privileged: true - volumes: - - $PWD:/app - - ../../data:/data - - openpi_server: - image: openpi_server - build: - context: ../.. 
- dockerfile: scripts/docker/serve_policy.Dockerfile - init: true - tty: true - network_mode: host - volumes: - - $PWD:/app - - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets - environment: - - SERVER_ARGS - - OPENPI_DATA_HOME=/openpi_assets - - IS_DOCKER=true - - # Comment out this block if not running on a machine with GPUs. - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] diff --git a/capvector-pi05/examples/aloha_sim/env.py b/capvector-pi05/examples/aloha_sim/env.py deleted file mode 100644 index 3c0edb5905d16a44f1907ee20f44936a7bc23f6b..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/env.py +++ /dev/null @@ -1,56 +0,0 @@ -import gym_aloha # noqa: F401 -import gymnasium -import numpy as np -from openpi_client import image_tools -from openpi_client.runtime import environment as _environment -from typing_extensions import override - - -class AlohaSimEnvironment(_environment.Environment): - """An environment for an Aloha robot in simulation.""" - - def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None: - np.random.seed(seed) - self._rng = np.random.default_rng(seed) - - self._gym = gymnasium.make(task, obs_type=obs_type) - - self._last_obs = None - self._done = True - self._episode_reward = 0.0 - - @override - def reset(self) -> None: - gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1))) - self._last_obs = self._convert_observation(gym_obs) # type: ignore - self._done = False - self._episode_reward = 0.0 - - @override - def is_episode_complete(self) -> bool: - return self._done - - @override - def get_observation(self) -> dict: - if self._last_obs is None: - raise RuntimeError("Observation is not set. Call reset() first.") - - return self._last_obs # type: ignore - - @override - def apply_action(self, action: dict) -> None: - gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"]) - self._last_obs = self._convert_observation(gym_obs) # type: ignore - self._done = terminated or truncated - self._episode_reward = max(self._episode_reward, reward) - - def _convert_observation(self, gym_obs: dict) -> dict: - img = gym_obs["pixels"]["top"] - img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224)) - # Convert axis order from [H, W, C] --> [C, H, W] - img = np.transpose(img, (2, 0, 1)) - - return { - "state": gym_obs["agent_pos"], - "images": {"cam_high": img}, - } diff --git a/capvector-pi05/examples/aloha_sim/main.py b/capvector-pi05/examples/aloha_sim/main.py deleted file mode 100644 index 9d5a56becfe9247010c893d87c5dc4d8114f3b77..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/main.py +++ /dev/null @@ -1,55 +0,0 @@ -import dataclasses -import logging -import pathlib - -import env as _env -from openpi_client import action_chunk_broker -from openpi_client import websocket_client_policy as _websocket_client_policy -from openpi_client.runtime import runtime as _runtime -from openpi_client.runtime.agents import policy_agent as _policy_agent -import saver as _saver -import tyro - - -@dataclasses.dataclass -class Args: - out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos") - - task: str = "gym_aloha/AlohaTransferCube-v0" - seed: int = 0 - - action_horizon: int = 10 - - host: str = "0.0.0.0" - port: int = 8000 - - display: bool = False - - -def main(args: Args) -> None: - runtime = _runtime.Runtime( - environment=_env.AlohaSimEnvironment( - task=args.task, - 
seed=args.seed, - ), - agent=_policy_agent.PolicyAgent( - policy=action_chunk_broker.ActionChunkBroker( - policy=_websocket_client_policy.WebsocketClientPolicy( - host=args.host, - port=args.port, - ), - action_horizon=args.action_horizon, - ) - ), - subscribers=[ - _saver.VideoSaver(args.out_dir), - ], - max_hz=50, - ) - - runtime.run() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, force=True) - tyro.cli(main) diff --git a/capvector-pi05/examples/aloha_sim/requirements.in b/capvector-pi05/examples/aloha_sim/requirements.in deleted file mode 100644 index ab7257ca737cf51ab2673924121f633937e230e9..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/requirements.in +++ /dev/null @@ -1,8 +0,0 @@ -gym-aloha -imageio -matplotlib -msgpack -numpy>=1.22.4,<2.0.0 -typing-extensions -tyro -websockets \ No newline at end of file diff --git a/capvector-pi05/examples/aloha_sim/requirements.txt b/capvector-pi05/examples/aloha_sim/requirements.txt deleted file mode 100644 index 99ff9491dc16f3e6cde1f4c3652b1a181ea5b1a1..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/requirements.txt +++ /dev/null @@ -1,132 +0,0 @@ -# This file was autogenerated by uv via the following command: -# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10 -absl-py==2.1.0 - # via - # dm-control - # dm-env - # labmaze - # mujoco -certifi==2024.8.30 - # via requests -charset-normalizer==3.4.0 - # via requests -cloudpickle==3.1.0 - # via gymnasium -contourpy==1.3.1 - # via matplotlib -cycler==0.12.1 - # via matplotlib -dm-control==1.0.14 - # via gym-aloha -dm-env==1.6 - # via dm-control -dm-tree==0.1.8 - # via - # dm-control - # dm-env -docstring-parser==0.16 - # via tyro -farama-notifications==0.0.4 - # via gymnasium -fonttools==4.55.2 - # via matplotlib -glfw==2.8.0 - # via - # dm-control - # mujoco -gym-aloha==0.1.1 - # via -r examples/aloha_sim/requirements.in -gymnasium==1.0.0 - # via gym-aloha -idna==3.10 - # via requests -imageio==2.36.1 - # via - # -r examples/aloha_sim/requirements.in - # gym-aloha -imageio-ffmpeg==0.5.1 - # via imageio -kiwisolver==1.4.7 - # via matplotlib -labmaze==1.0.6 - # via dm-control -lxml==5.3.0 - # via dm-control -markdown-it-py==3.0.0 - # via rich -matplotlib==3.9.3 - # via -r examples/aloha_sim/requirements.in -mdurl==0.1.2 - # via markdown-it-py -msgpack==1.1.0 - # via -r examples/aloha_sim/requirements.in -mujoco==2.3.7 - # via - # dm-control - # gym-aloha -numpy==1.26.4 - # via - # -r examples/aloha_sim/requirements.in - # contourpy - # dm-control - # dm-env - # gymnasium - # imageio - # labmaze - # matplotlib - # mujoco - # scipy -packaging==24.2 - # via matplotlib -pillow==11.0.0 - # via - # imageio - # matplotlib -protobuf==5.29.1 - # via dm-control -psutil==6.1.0 - # via imageio -pygments==2.18.0 - # via rich -pyopengl==3.1.7 - # via - # dm-control - # mujoco -pyparsing==3.2.0 - # via - # dm-control - # matplotlib -python-dateutil==2.9.0.post0 - # via matplotlib -requests==2.32.3 - # via dm-control -rich==13.9.4 - # via tyro -scipy==1.14.1 - # via dm-control -setuptools==75.6.0 - # via - # dm-control - # imageio-ffmpeg - # labmaze -shtab==1.7.1 - # via tyro -six==1.17.0 - # via python-dateutil -tqdm==4.67.1 - # via dm-control -typeguard==4.4.1 - # via tyro -typing-extensions==4.12.2 - # via - # -r examples/aloha_sim/requirements.in - # gymnasium - # rich - # typeguard - # tyro -tyro==0.9.2 - # via -r examples/aloha_sim/requirements.in 
-urllib3==2.2.3 - # via requests -websockets==14.1 - # via -r examples/aloha_sim/requirements.in diff --git a/capvector-pi05/examples/aloha_sim/saver.py b/capvector-pi05/examples/aloha_sim/saver.py deleted file mode 100644 index 5928268d86c6f480e9ea27a74116617535c1bc0a..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/aloha_sim/saver.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -import pathlib - -import imageio -import numpy as np -from openpi_client.runtime import subscriber as _subscriber -from typing_extensions import override - - -class VideoSaver(_subscriber.Subscriber): - """Saves episode data.""" - - def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None: - out_dir.mkdir(parents=True, exist_ok=True) - self._out_dir = out_dir - self._images: list[np.ndarray] = [] - self._subsample = subsample - - @override - def on_episode_start(self) -> None: - self._images = [] - - @override - def on_step(self, observation: dict, action: dict) -> None: - im = observation["images"]["cam_high"] # [C, H, W] - im = np.transpose(im, (1, 2, 0)) # [H, W, C] - self._images.append(im) - - @override - def on_episode_end(self) -> None: - existing = list(self._out_dir.glob("out_[0-9]*.mp4")) - next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1 - out_path = self._out_dir / f"out_{next_idx}.mp4" - - logging.info(f"Saving video to {out_path}") - imageio.mimwrite( - out_path, - [np.asarray(x) for x in self._images[:: self._subsample]], - fps=50 // max(1, self._subsample), - ) diff --git a/capvector-pi05/examples/convert_jax_model_to_pytorch.py b/capvector-pi05/examples/convert_jax_model_to_pytorch.py deleted file mode 100644 index 7c2752fe9df5f0cb492832fd506d0fa854b7bd44..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/convert_jax_model_to_pytorch.py +++ /dev/null @@ -1,587 +0,0 @@ -#!/usr/bin/env python3 -""" -Load a JAX model and print all parameter keys, with optional conversion to PyTorch. - -This script loads a JAX model checkpoint using orbax and can either: -1. Print out all the parameter keys in a hierarchical structure for inspection -2. 
Convert the JAX model to PyTorch format using our PI0Pytorch model - -Usage: - # Just inspect keys: - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only - - # Convert to PyTorch: - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output - -Example: - # pi0_droid - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch - - # pi0_aloha_sim - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch - - # pi05_droid - python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch -""" - -import json -import os -import pathlib -import shutil -from typing import Literal - -from flax.nnx import traversals -import numpy as np -import orbax.checkpoint as ocp -import safetensors -import torch -import tyro - -import openpi.models.gemma -import openpi.models.model -import openpi.models.pi0_config -import openpi.models_pytorch.pi0_pytorch -from openpi.training import utils -import openpi.training.config as _config - - -def slice_paligemma_state_dict(state_dict, config): - """Convert PaliGemma JAX parameters to PyTorch format.""" - suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" - - # patch embeddings - jax_key = f"img/embedding/kernel{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight" - state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1) - - jax_key = f"img/embedding/bias{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias" - state_dict[pytorch_key] = state_dict.pop(jax_key) - - # positional embeddings - jax_key = f"img/pos_embedding{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight" - state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size) - - # extract vision layers to be sliced at index 0. There are 27 layers in the base model. 
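
The pops that follow pull out arrays in which every vision-encoder layer's parameters are stacked along a leading layer axis; the converter then slices index `i` for each PyTorch layer and transposes kernels into `nn.Linear`'s `[out, in]` layout. A toy sketch of that slicing idea, using made-up shapes rather than the real SigLIP/PaliGemma dimensions:

```python
# Illustrative only -- hypothetical module names and shapes, not the actual vision tower.
import numpy as np
import torch

num_layers, hidden = 4, 8
# JAX/Flax checkpoints here store one array per parameter, stacked over layers:
stacked_kernel = np.random.randn(num_layers, hidden, hidden).astype(np.float32)

pytorch_state = {}
for i in range(num_layers):
    # JAX Dense kernels are laid out [in, out]; torch.nn.Linear weights are [out, in],
    # hence the .transpose() applied throughout the converter.
    pytorch_state[f"encoder.layers.{i}.fc.weight"] = torch.from_numpy(
        stacked_kernel[i].transpose()
    )

print(pytorch_state["encoder.layers.0.fc.weight"].shape)  # torch.Size([8, 8])
```
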
- encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") - encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") - encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") - encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") - - encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") - encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") - encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") - encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") - - encoderblock_attention_0_key_kernel = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}" - ) - encoderblock_attention_0_key_bias = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}" - ) - encoderblock_attention_0_value_kernel = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}" - ) - encoderblock_attention_0_value_bias = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}" - ) - encoderblock_attention_0_query_kernel = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}" - ) - encoderblock_attention_0_query_bias = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}" - ) - encoderblock_attention_0_out_kernel = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}" - ) - encoderblock_attention_0_out_bias = state_dict.pop( - f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}" - ) - - for i in range(config.vision_config.num_hidden_layers): - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight" - ] = encoderblock_layernorm0_scale[i].transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias" - ] = encoderblock_layernorm0_bias[i] - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight" - ] = encoderblock_layernorm1_scale[i].transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias" - ] = encoderblock_layernorm1_bias[i] - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight" - ] = encoderblock_mlp_dense0_kernel[i].transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias" - ] = encoderblock_mlp_dense0_bias[i] - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight" - ] = encoderblock_mlp_dense1_kernel[i].transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias" - ] = encoderblock_mlp_dense1_bias[i] - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight" - ] = 
encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias" - ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight" - ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias" - ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight" - ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias" - ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight" - ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[ - f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias" - ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - - jax_key = f"img/Transformer/encoder_norm/scale{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight" - state_dict[pytorch_key] = state_dict.pop(jax_key).transpose() - - jax_key = f"img/Transformer/encoder_norm/bias{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias" - state_dict[pytorch_key] = state_dict.pop(jax_key) - - # multimodal projector - jax_key = f"img/head/kernel{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight" - state_dict[pytorch_key] = state_dict.pop(jax_key).transpose() - - jax_key = f"img/head/bias{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias" - state_dict[pytorch_key] = state_dict.pop(jax_key) - - # text decoder (gemma) - jax_key = f"llm/embedder/input_embedding{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" - state_dict[pytorch_key] = state_dict.pop(jax_key) - - # pop the einsum attention + mlp representations - llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") - llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") - llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") - - llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") - llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") - - llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") - llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") - - for i in range(config.text_config.num_hidden_layers): - q_proj_weight_reshaped = ( - llm_attention_q_einsum[i] - .transpose(0, 2, 1) - .reshape( 
- config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size - ) - ) - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = ( - q_proj_weight_reshaped - ) - - k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = ( - k_proj_weight_reshaped - ) - v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = ( - v_proj_weight_reshaped - ) - - o_proj_weight_reshaped = ( - llm_attention_attn_vec_einsum[i] - .transpose(2, 0, 1) - .reshape( - config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size - ) - ) - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = ( - o_proj_weight_reshaped - ) - - gate_proj_weight = llm_mlp_gating_einsum[i, 0] - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = ( - gate_proj_weight.transpose() - ) - up_proj_weight = llm_mlp_gating_einsum[i, 1] - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = ( - up_proj_weight.transpose() - ) - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = ( - llm_mlp_linear[i].transpose() - ) - state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = ( - llm_input_layernorm[i] - ) - state_dict[ - f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight" - ] = llm_post_attention_layernorm[i] - - jax_key = f"llm/final_norm/scale{suffix}" - pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight" - state_dict[pytorch_key] = state_dict.pop(jax_key) - - expert_dict = {} - final_state_dict = {} - - # Expert-related keys to extract (including pi05 Dense layer parameters) - expert_keys = [ - f"llm/final_norm_1/scale{suffix}", - f"llm/final_norm_1/Dense_0/bias{suffix}", - f"llm/final_norm_1/Dense_0/kernel{suffix}", - f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", - f"llm/layers/attn/kv_einsum_1/w{suffix}", - f"llm/layers/attn/q_einsum_1/w{suffix}", - f"llm/layers/mlp_1/gating_einsum{suffix}", - f"llm/layers/mlp_1/linear{suffix}", - f"llm/layers/pre_attention_norm_1/scale{suffix}", - f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}", - f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}", - f"llm/layers/pre_ffw_norm_1/scale{suffix}", - f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}", - f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}", - ] - - for key, value in state_dict.items(): - if key not in expert_keys: - final_state_dict[key] = torch.from_numpy(value) - else: - expert_dict[key] = value - - return final_state_dict, expert_dict - - -def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05): - """Convert Gemma JAX parameters to PyTorch format.""" - # Add missing attributes to config if they don't exist - if not hasattr(config, "vocab_size"): - config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE - if not hasattr(config, "hidden_size"): - config.hidden_size = config.width - if not hasattr(config, "num_hidden_layers"): - config.num_hidden_layers = config.depth - if not hasattr(config, "num_attention_heads"): - 
config.num_attention_heads = config.num_heads - - suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" - - llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") - llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") - llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") - - llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") - llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") - - # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0) - if "pi05" in checkpoint_dir: - # Pi05 with adaptive normalization - llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}") - llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}") - llm_input_layernorm_kernel = state_dict.pop( - f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}" - ) - llm_post_attention_layernorm_kernel = state_dict.pop( - f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}" - ) - else: - # Regular pi0 with standard RMSNorm - llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") - llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") - - for i in range(config.num_hidden_layers): - q_proj_weight_reshaped = ( - llm_attention_q_einsum[i] - .transpose(0, 2, 1) - .reshape(config.num_attention_heads * config.head_dim, config.hidden_size) - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = ( - q_proj_weight_reshaped - ) - - k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = ( - k_proj_weight_reshaped - ) - v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = ( - v_proj_weight_reshaped - ) - - o_proj_weight_reshaped = ( - llm_attention_attn_vec_einsum[i] - .reshape(config.num_attention_heads * config.head_dim, config.hidden_size) - .transpose(1, 0) - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = ( - o_proj_weight_reshaped - ) - - gate_proj_weight = llm_mlp_gating_einsum[i, 0] - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = ( - gate_proj_weight.transpose() - ) - up_proj_weight = llm_mlp_gating_einsum[i, 1] - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = ( - up_proj_weight.transpose() - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[ - i - ].transpose() - - if "pi05" in checkpoint_dir: - # Pi05 with adaptive normalization - use Dense layer parameters directly - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = ( - llm_input_layernorm_bias[i] - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = ( - llm_post_attention_layernorm_bias[i] - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = ( - 
llm_input_layernorm_kernel[i].transpose() - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = ( - llm_post_attention_layernorm_kernel[i].transpose() - ) - else: - # Regular pi0 with standard RMSNorm - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = ( - llm_input_layernorm[i] - ) - state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = ( - llm_post_attention_layernorm[i] - ) - - # Handle final norm layer - if "pi05" in checkpoint_dir: - # Pi05 with adaptive normalization - use Dense layer parameters directly - final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}") - final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}") - state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias - state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose() - else: - # Regular pi0 with standard RMSNorm - state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop( - f"llm/final_norm_{num_expert}/scale{suffix}" - ) - - # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. - - final_state_dict = {} - for key, value in state_dict.items(): - if not isinstance(value, torch.Tensor): - final_state_dict[key] = torch.from_numpy(value) - else: - final_state_dict[key] = value - - return final_state_dict - - -def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None): - """Load and process params by restoring via JAX model loader first. - This respects dtype conversions that occur during model restore. - """ - # Use repository restore utility to load a pure dict of params (value suffix removed) - params = openpi.models.model.restore_params( - f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision - ) - - return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params} - - -def load_jax_model_and_print_keys(checkpoint_dir: str): - """ - Load JAX model from checkpoint and print all parameter keys. - - Args: - checkpoint_dir: Path to the checkpoint directory - """ - checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir - # Initialize checkpointer - checkpointer = ocp.PyTreeCheckpointer() - metadata = checkpointer.metadata(f"{checkpoint_dir}/params") - print(utils.array_tree_to_info(metadata)) - - -def convert_pi0_checkpoint( - checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config -): - """ - Convert PI0 JAX checkpoint to PyTorch format. 
- - Args: - checkpoint_dir: Path to the JAX checkpoint - precision: Model precision (float32, bfloat16, float16) - output_path: Path to save the converted PyTorch model - model_config: Model config - """ - print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}") - print(f"Model config: {model_config}") - - # Break down orbax ckpts by restoring via JAX to respect dtype - initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32") - - # Process projection params - if model_config.pi05: - keys = [ - "action_in_proj", - "action_out_proj", - "time_mlp_in", - "time_mlp_out", - ] - else: - keys = [ - "state_proj", - "action_in_proj", - "action_out_proj", - "action_time_mlp_in", - "action_time_mlp_out", - ] - - projection_params = {} - for key in keys: - kernel_params = initial_params["projection_params"][key]["kernel"] - bias_params = initial_params["projection_params"][key]["bias"] - if isinstance(kernel_params, dict): - weight = kernel_params["value"] - bias = bias_params["value"] - else: - weight = kernel_params - bias = bias_params - - pytorch_weight_key = f"{key}.weight" - pytorch_bias_key = f"{key}.bias" - - projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T - projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias)) - - # Create configs based on checkpoint path - # All models use the same PaliGemma config structure - class PaliGemmaConfig: - def __init__(self): - self.vision_config = type( - "obj", - (object,), - { - "hidden_size": 1152, - "num_hidden_layers": 27, - "num_attention_heads": 16, - "intermediate_size": 4304, - "patch_size": 14, - "projection_dim": 2048, - }, - )() - self.text_config = type( - "obj", - (object,), - { - "hidden_size": 2048, - "num_hidden_layers": 18, - "num_attention_heads": 8, - "head_dim": 256, - "intermediate_size": 16384, - }, - )() - - paligemma_config = PaliGemmaConfig() - action_expert_config = openpi.models.gemma.get_config("gemma_300m") - - # Process PaliGemma weights - paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config) - - # Process Gemma weights from expert_params - gemma_params = slice_gemma_state_dict( - expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05 - ) - - # Instantiate model - pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config) - - # Combine all parameters (no prefix needed for our model structure) - all_params = {**paligemma_params, **gemma_params, **projection_params} - - # Load state dict - pi0_model.load_state_dict(all_params, strict=False) - - if precision == "float32": - pi0_model = pi0_model.to(torch.float32) - elif precision == "bfloat16": - pi0_model = pi0_model.to(torch.bfloat16) - else: - raise ValueError(f"Invalid precision: {precision}") - - # Save the converted model using safetensors - os.makedirs(output_path, exist_ok=True) - - # Save model weights as SafeTensors using save_model to handle tied weights - safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors")) - - # Copy assets folder if it exists - assets_source = pathlib.Path(checkpoint_dir).parent / "assets" - if assets_source.exists(): - assets_dest = pathlib.Path(output_path) / "assets" - if assets_dest.exists(): - shutil.rmtree(assets_dest) - shutil.copytree(assets_source, assets_dest) - - # Save config as JSON for reference - config_dict = { - "action_dim": model_config.action_dim, - "action_horizon": 
model_config.action_horizon, - "paligemma_variant": model_config.paligemma_variant, - "action_expert_variant": model_config.action_expert_variant, - "precision": precision, - } - with open(os.path.join(output_path, "config.json"), "w") as f: - json.dump(config_dict, f, indent=2) - - print("Model conversion completed successfully!") - print(f"Model saved to {output_path}") - - -def main( - checkpoint_dir: str, - config_name: str, - output_path: str | None = None, - precision: Literal["float32", "bfloat16", "float16"] = "bfloat16", - *, - inspect_only: bool = False, -): - """Load JAX model and optionally convert to PyTorch. - - Args: - checkpoint_dir: Path to the JAX checkpoint directory - output_path: Path to save converted PyTorch model (required for conversion) - precision: Precision for model conversion - inspect_only: Only inspect parameter keys, don't convert - """ - model_config = _config.get_config(config_name).model - if not isinstance(model_config, openpi.models.pi0_config.Pi0Config): - raise ValueError(f"Config {config_name} is not a Pi0Config") - if inspect_only: - load_jax_model_and_print_keys(checkpoint_dir) - else: - if not output_path: - print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.") - return - convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config) - - -if __name__ == "__main__": - tyro.cli(main) diff --git a/capvector-pi05/examples/droid/README.md b/capvector-pi05/examples/droid/README.md deleted file mode 100644 index abef943e97cca8d95d4c0fea3f6d601affbc0582..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/droid/README.md +++ /dev/null @@ -1,84 +0,0 @@ -# DROID Policies in openpi - -We offer instructions for: -- [Running inference for our best $pi_{0.5}$-DROID policy](./README.md#running-droid-inference) -- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies) -- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid) -- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets) - -## Running DROID Inference - -This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy. - - -### Step 1: Start a policy server - -Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference. - -1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi). -2. Start the OpenPI server via the following command: - -```bash -uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid -``` - -You can also run the equivalent command below: - -```bash -uv run scripts/serve_policy.py --env=DROID -``` - -### Step 2: Run the DROID robot - -1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC. -2. On the control laptop, activate your DROID conda environment. -3. 
Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`. -4. Install `tyro`, which we will use for command line parsing: `pip install tyro`. -5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory. -6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with). -7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping ` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"]. - -```bash -python3 scripts/main.py --remote_host= --remote_port= --external_camera="left" -``` - -The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting! - -## Troubleshooting - -| Issue | Solution | -|-------|----------| -| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping ` from the DROID laptop. | -| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. | -| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). | -| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) | - - -## Running Other Policies - -We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot. 
- -``` -# Train from pi0-FAST, using FAST tokenizer -uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid - -# Train from pi0, using flow matching -uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid - -# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer. -uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid - -# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer). -uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid - -# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset). -uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid - -# Trained from PaliGemma, using FSQ tokenizer. -uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid - -# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma. -uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid -``` - -You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py). diff --git a/capvector-pi05/examples/droid/README_train.md b/capvector-pi05/examples/droid/README_train.md deleted file mode 100644 index f3fbe44ec56d6a46ae0919f8ffebbe6f92b411cd..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/droid/README_train.md +++ /dev/null @@ -1,106 +0,0 @@ -# Training on DROID - -Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline. -(small differences in data loading and the used action space) -- For a tutorial on how to fine-tune your model with a smaller, custom dataset collected on the DROID platform, see below. - -In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough -for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset. - -## Install - -We need a few additional dependencies for RLDS data loading. Run: -```bash -uv sync --group rlds -``` - -## Download DROID dataset - -You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI): -``` -gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 /droid/1.0.1 -``` - -Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py). - -You will need 1.8TB of disk storage to download the DROID RLDS dataset. 
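
Before kicking off training, it can help to confirm the downloaded shards are readable. A minimal, hypothetical sanity check (not part of the repo), assuming the `rlds` extras from the install step above are available and the data was downloaded to `/droid/1.0.1`:

```python
# Hypothetical sanity check: load one episode from the downloaded DROID RLDS shards.
import tensorflow_datasets as tfds

builder = tfds.builder_from_directory(builder_dir="/droid/1.0.1")  # path is an assumption
ds = builder.as_dataset(split="train", shuffle_files=False)

episode = next(iter(ds.take(1)))
print(episode["episode_metadata"]["file_path"].numpy().decode())
print("steps in episode:", sum(1 for _ in episode["steps"]))
```
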
- -## Run - -First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)). - -Then, compute normalization statistics (this will take ~10 minutes): -```bash -uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000 -``` - -Run training: -```bash -XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite -``` - -**Note**: The original pi0.5-DROID model was trained with joint velocity actions. -Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate). -Thus, we do not recommend training with joint velocity actions and instead use joint position actions here. - - -## Compute Requirements - -Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch). -If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on 8x H100s (240k iterations, i.e. 3 epochs). - -We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far. - - -## Data Filtering - -Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance. - -By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path=""` argument in [src/openpi/training/config.py](src/openpi/training/config.py). - -**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices. - -## RoboArena - -Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :) - -If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com). - - -# Fine-Tuning on Custom DROID Datasets - -Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. 
Like for other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it. - -Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for it's better efficiency (see the example above). - - -## Step 1: Converting your custom DROID dataset to LeRobot - -We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB): -``` -gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 -``` - -We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run: -``` -gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json -``` - -For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py). - -Now, we will use the `convert_droid_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations): -``` -uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir -``` - -## Step 2: Run fine-tuning with your custom dataset - -Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created. -You can modify the config easily to work with other base models, or use your custom DROID dataset in `config.py` (seach for `pi05_droid_finetune`). - -To launch training: -``` -uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite -``` - -Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot. - diff --git a/capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py b/capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py deleted file mode 100644 index a79ab422144a6b4cffd073c9267db0aa7fec3c26..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps -that should be sampled during training (all others are filtered out). - -Filtering logic: -We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames -(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering -this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle -ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last -filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions). - -This leaves us with trajectory segments consisting of contiguous, significant movement. 
Training on this filtered set -yields policies that output fewer stationary actions (i.e., get "stuck" in states less). -""" - -import json -import os -from pathlib import Path - -import numpy as np -import tensorflow as tf -import tensorflow_datasets as tfds -from tqdm import tqdm - -os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU - -builder = tfds.builder_from_directory( - # path to the `droid` directory (not its parent) - builder_dir="", -) -ds = builder.as_dataset(split="train", shuffle_files=False) -tf.data.experimental.ignore_errors(ds) - -keep_ranges_path = "" - -min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out -min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out -filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range - -keep_ranges_map = {} -if Path(keep_ranges_path).exists(): - with Path(keep_ranges_path).open("r") as f: - keep_ranges_map = json.load(f) - print(f"Resuming from {len(keep_ranges_map)} episodes already processed") - -for ep_idx, ep in enumerate(tqdm(ds)): - recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode() - file_path = ep["episode_metadata"]["file_path"].numpy().decode() - - key = f"{recording_folderpath}--{file_path}" - if key in keep_ranges_map: - continue - - joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]] - joint_velocities = np.array(joint_velocities) - - is_idle_array = np.hstack( - [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)] - ) - - # Find what steps go from idle to non-idle and vice-versa - is_idle_padded = np.concatenate( - [[False], is_idle_array, [False]] - ) # Start and end with False, so idle at first step is a start of motion - - is_idle_diff = np.diff(is_idle_padded.astype(int)) - is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle - is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle - - # Find which steps correspond to idle segments of length at least min_idle_len - true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len - is_idle_true_starts = is_idle_true_starts[true_segment_masks] - is_idle_true_ends = is_idle_true_ends[true_segment_masks] - - keep_mask = np.ones(len(joint_velocities), dtype=bool) - for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True): - keep_mask[start:end] = False - - # Get all non-idle ranges of at least 16 - # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len - keep_padded = np.concatenate([[False], keep_mask, [False]]) - - keep_diff = np.diff(keep_padded.astype(int)) - keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep - keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out - - # Find which steps correspond to non-idle segments of length at least min_non_idle_len - true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len - keep_true_starts = keep_true_starts[true_segment_masks] - keep_true_ends = keep_true_ends[true_segment_masks] - - # Add mapping from episode unique ID key to list of non-idle ranges to keep - keep_ranges_map[key] = [] - for start, end in zip(keep_true_starts, 
keep_true_ends, strict=True): - keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges)) - - if ep_idx % 1000 == 0: - with Path(keep_ranges_path).open("w") as f: - json.dump(keep_ranges_map, f) - -print("Done!") -with Path(keep_ranges_path).open("w") as f: - json.dump(keep_ranges_map, f) diff --git a/capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py b/capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py deleted file mode 100644 index ab8c5ecdc3c1ddfe44d0f2baa2e1e2e7851e1727..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py +++ /dev/null @@ -1,477 +0,0 @@ -""" -Minimal example script for converting a dataset collected on the DROID platform to LeRobot format. - -Usage: -uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data - -If you want to push your dataset to the Hugging Face Hub, you can use the following command: -uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub - -The resulting dataset will get saved to the $LEROBOT_HOME directory. -""" - -from collections import defaultdict -import copy -import glob -import json -from pathlib import Path -import shutil - -import cv2 -import h5py -from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -import numpy as np -from PIL import Image -from tqdm import tqdm -import tyro - -REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub - - -def resize_image(image, size): - image = Image.fromarray(image) - return np.array(image.resize(size, resample=Image.BICUBIC)) - - -def main(data_dir: str, *, push_to_hub: bool = False): - # Clean up any existing dataset in the output directory - output_path = HF_LEROBOT_HOME / REPO_NAME - if output_path.exists(): - shutil.rmtree(output_path) - data_dir = Path(data_dir) - - # Create LeRobot dataset, define features to store - # We will follow the DROID data naming conventions here. 
- # LeRobot assumes that dtype of image data is `image` - dataset = LeRobotDataset.create( - repo_id=REPO_NAME, - robot_type="panda", - fps=15, # DROID data is typically recorded at 15fps - features={ - # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention) - "exterior_image_1_left": { - "dtype": "image", - "shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset - "names": ["height", "width", "channel"], - }, - "exterior_image_2_left": { - "dtype": "image", - "shape": (180, 320, 3), - "names": ["height", "width", "channel"], - }, - "wrist_image_left": { - "dtype": "image", - "shape": (180, 320, 3), - "names": ["height", "width", "channel"], - }, - "joint_position": { - "dtype": "float32", - "shape": (7,), - "names": ["joint_position"], - }, - "gripper_position": { - "dtype": "float32", - "shape": (1,), - "names": ["gripper_position"], - }, - "actions": { - "dtype": "float32", - "shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D) - "names": ["actions"], - }, - }, - image_writer_threads=10, - image_writer_processes=5, - ) - - # Load language annotations - # Note: we load the DROID language annotations for this example, but you can manually define them for your own data - with (data_dir / "aggregated-annotations-030724.json").open() as f: - language_annotations = json.load(f) - - # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset - # We assume the following directory structure: - # RAW_DROID_PATH/ - # - <...>/ - # - recordings/ - # - MP4/ - # - .mp4 # single-view video of left stereo pair camera - # - trajectory.hdf5 - # - <...>/ - episode_paths = list(data_dir.glob("**/trajectory.h5")) - print(f"Found {len(episode_paths)} episodes for conversion") - - # We will loop over each dataset_name and write episodes to the LeRobot dataset - for episode_path in tqdm(episode_paths, desc="Converting episodes"): - # Load raw data - recording_folderpath = episode_path.parent / "recordings" / "MP4" - trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath)) - - # To load the language instruction, we need to parse out the episode_id from the metadata file - # Again, you can modify this step for your own data, to load your own language instructions - metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json"))) - episode_id = metadata_filepath.name.split(".")[0].split("_")[-1] - language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[ - "language_instruction1" - ] - print(f"Converting episode with language instruction: {language_instruction}") - - # Write to LeRobot dataset - for step in trajectory: - camera_type_dict = step["observation"]["camera_type"] - wrist_ids = [k for k, v in camera_type_dict.items() if v == 0] - exterior_ids = [k for k, v in camera_type_dict.items() if v != 0] - dataset.add_frame( - { - # Note: need to flip BGR --> RGB for loaded images - "exterior_image_1_left": resize_image( - step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180) - ), - "exterior_image_2_left": resize_image( - step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180) - ), - "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)), - "joint_position": np.asarray( - step["observation"]["robot_state"]["joint_positions"], dtype=np.float32 - ), - "gripper_position": np.asarray( - 
step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32 - ), - # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions - "actions": np.concatenate( - [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32 - ), - "task": language_instruction, - } - ) - dataset.save_episode() - - # Optionally push to the Hugging Face Hub - if push_to_hub: - dataset.push_to_hub( - tags=["libero", "panda", "rlds"], - private=False, - push_videos=True, - license="apache-2.0", - ) - - -########################################################################################################## -################ The rest of this file are functions to parse the raw DROID data ######################### -################ You don't need to worry about understanding this part ######################### -################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py -########################################################################################################## - - -camera_type_dict = { - "hand_camera_id": 0, - "varied_camera_1_id": 1, - "varied_camera_2_id": 1, -} - -camera_type_to_string_dict = { - 0: "hand_camera", - 1: "varied_camera", - 2: "fixed_camera", -} - - -def get_camera_type(cam_id): - if cam_id not in camera_type_dict: - return None - type_int = camera_type_dict[cam_id] - return camera_type_to_string_dict[type_int] - - -class MP4Reader: - def __init__(self, filepath, serial_number): - # Save Parameters # - self.serial_number = serial_number - self._index = 0 - - # Open Video Reader # - self._mp4_reader = cv2.VideoCapture(filepath) - if not self._mp4_reader.isOpened(): - raise RuntimeError("Corrupted MP4 File") - - def set_reading_parameters( - self, - image=True, # noqa: FBT002 - concatenate_images=False, # noqa: FBT002 - resolution=(0, 0), - resize_func=None, - ): - # Save Parameters # - self.image = image - self.concatenate_images = concatenate_images - self.resolution = resolution - self.resize_func = cv2.resize - self.skip_reading = not image - if self.skip_reading: - return - - def get_frame_resolution(self): - width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH) - height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT) - return (width, height) - - def get_frame_count(self): - if self.skip_reading: - return 0 - return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT)) - - def set_frame_index(self, index): - if self.skip_reading: - return - - if index < self._index: - self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1) - self._index = index - - while self._index < index: - self.read_camera(ignore_data=True) - - def _process_frame(self, frame): - frame = copy.deepcopy(frame) - if self.resolution == (0, 0): - return frame - return self.resize_func(frame, self.resolution) - - def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002 - # Skip if Read Unnecesary # - if self.skip_reading: - return {} - - # Read Camera # - success, frame = self._mp4_reader.read() - - self._index += 1 - if not success: - return None - if ignore_data: - return None - - # Return Data # - data_dict = {} - - if self.concatenate_images or "stereo" not in self.serial_number: - data_dict["image"] = {self.serial_number: self._process_frame(frame)} - else: - single_width = frame.shape[1] // 2 - data_dict["image"] = { - self.serial_number + "_left": self._process_frame(frame[:, 
:single_width, :]), - self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]), - } - - return data_dict - - def disable_camera(self): - if hasattr(self, "_mp4_reader"): - self._mp4_reader.release() - - -class RecordedMultiCameraWrapper: - def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006 - # Save Camera Info # - self.camera_kwargs = camera_kwargs - - # Open Camera Readers # - mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4") - all_filepaths = mp4_filepaths - - self.camera_dict = {} - for f in all_filepaths: - serial_number = f.split("/")[-1][:-4] - cam_type = get_camera_type(serial_number) - camera_kwargs.get(cam_type, {}) - - if f.endswith(".mp4"): - Reader = MP4Reader # noqa: N806 - else: - raise ValueError - - self.camera_dict[serial_number] = Reader(f, serial_number) - - def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006 - full_obs_dict = defaultdict(dict) - - # Read Cameras In Randomized Order # - all_cam_ids = list(self.camera_dict.keys()) - # random.shuffle(all_cam_ids) - - for cam_id in all_cam_ids: - if "stereo" in cam_id: - continue - try: - cam_type = camera_type_dict[cam_id] - except KeyError: - print(f"{self.camera_dict} -- {camera_type_dict}") - raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904 - curr_cam_kwargs = self.camera_kwargs.get(cam_type, {}) - self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs) - - timestamp = timestamp_dict.get(cam_id + "_frame_received", None) - if index is not None: - self.camera_dict[cam_id].set_frame_index(index) - - data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp) - - # Process Returned Data # - if data_dict is None: - return None - for key in data_dict: - full_obs_dict[key].update(data_dict[key]) - - return full_obs_dict - - -def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006 - length = None - - for key in hdf5_file: - if key in keys_to_ignore: - continue - - curr_data = hdf5_file[key] - if isinstance(curr_data, h5py.Group): - curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore) - elif isinstance(curr_data, h5py.Dataset): - curr_length = len(curr_data) - else: - raise ValueError - - if length is None: - length = curr_length - assert curr_length == length - - return length - - -def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006 - data_dict = {} - - for key in hdf5_file: - if key in keys_to_ignore: - continue - - curr_data = hdf5_file[key] - if isinstance(curr_data, h5py.Group): - data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore) - elif isinstance(curr_data, h5py.Dataset): - data_dict[key] = curr_data[index] - else: - raise ValueError - - return data_dict - - -class TrajectoryReader: - def __init__(self, filepath, read_images=True): # noqa: FBT002 - self._hdf5_file = h5py.File(filepath, "r") - is_video_folder = "observations/videos" in self._hdf5_file - self._read_images = read_images and is_video_folder - self._length = get_hdf5_length(self._hdf5_file) - self._video_readers = {} - self._index = 0 - - def length(self): - return self._length - - def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006 - # Make Sure We Read Within Range # - if index is None: - index = self._index - else: - assert not self._read_images - self._index = index - assert index < self._length - - # Load Low Dimensional Data # - keys_to_ignore = [*keys_to_ignore.copy(), "videos"] - timestep = 
load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore) - - # Increment Read Index # - self._index += 1 - - # Return Timestep # - return timestep - - def close(self): - self._hdf5_file.close() - - -def load_trajectory( - filepath=None, - read_cameras=True, # noqa: FBT002 - recording_folderpath=None, - camera_kwargs={}, # noqa: B006 - remove_skipped_steps=False, # noqa: FBT002 - num_samples_per_traj=None, - num_samples_per_traj_coeff=1.5, -): - read_recording_folderpath = read_cameras and (recording_folderpath is not None) - - traj_reader = TrajectoryReader(filepath) - if read_recording_folderpath: - camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs) - - horizon = traj_reader.length() - timestep_list = [] - - # Choose Timesteps To Save # - if num_samples_per_traj: - num_to_save = num_samples_per_traj - if remove_skipped_steps: - num_to_save = int(num_to_save * num_samples_per_traj_coeff) - max_size = min(num_to_save, horizon) - indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False)) - else: - indices_to_save = np.arange(horizon) - - # Iterate Over Trajectory # - for i in indices_to_save: - # Get HDF5 Data # - timestep = traj_reader.read_timestep(index=i) - - # If Applicable, Get Recorded Data # - if read_recording_folderpath: - timestamp_dict = timestep["observation"]["timestamp"]["cameras"] - camera_type_dict = { - k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items() - } - camera_obs = camera_reader.read_cameras( - index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict - ) - camera_failed = camera_obs is None - - # Add Data To Timestep If Successful # - if camera_failed: - break - timestep["observation"].update(camera_obs) - - # Filter Steps # - step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True) - delete_skipped_step = step_skipped and remove_skipped_steps - - # Save Filtered Timesteps # - if delete_skipped_step: - del timestep - else: - timestep_list.append(timestep) - - # Remove Extra Transitions # - timestep_list = np.array(timestep_list) - if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj): - ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False) - timestep_list = timestep_list[ind_to_keep] - - # Close Readers # - traj_reader.close() - - # Return Data # - return timestep_list - - -if __name__ == "__main__": - tyro.cli(main) diff --git a/capvector-pi05/examples/droid/main.py b/capvector-pi05/examples/droid/main.py deleted file mode 100644 index bacb34da935e70910a1ca86d00f60859dd0912de..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/droid/main.py +++ /dev/null @@ -1,246 +0,0 @@ -# ruff: noqa - -import contextlib -import dataclasses -import datetime -import faulthandler -import os -import signal -import time -from moviepy.editor import ImageSequenceClip -import numpy as np -from openpi_client import image_tools -from openpi_client import websocket_client_policy -import pandas as pd -from PIL import Image -from droid.robot_env import RobotEnv -import tqdm -import tyro - -faulthandler.enable() - -# DROID data collection frequency -- we slow down execution to match this frequency -DROID_CONTROL_FREQUENCY = 15 - - -@dataclasses.dataclass -class Args: - # Hardware parameters - left_camera_id: str = "" # e.g., "24259877" - right_camera_id: str = "" # e.g., "24514023" - wrist_camera_id: str = "" # e.g., "13062452" - - # 
Policy parameters - external_camera: str | None = ( - None # which external camera should be fed to the policy, choose from ["left", "right"] - ) - - # Rollout parameters - max_timesteps: int = 600 - # How many actions to execute from a predicted action chunk before querying policy server again - # 8 is usually a good default (equals 0.5 seconds of action execution). - open_loop_horizon: int = 8 - - # Remote server parameters - remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100" - remote_port: int = ( - 8000 # point this to the port of the policy server, default server port for openpi servers is 8000 - ) - - -# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is -# waiting for a new action chunk, it will raise an exception and the server connection dies. -# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete. -@contextlib.contextmanager -def prevent_keyboard_interrupt(): - """Temporarily prevent keyboard interrupts by delaying them until after the protected code.""" - interrupted = False - original_handler = signal.getsignal(signal.SIGINT) - - def handler(signum, frame): - nonlocal interrupted - interrupted = True - - signal.signal(signal.SIGINT, handler) - try: - yield - finally: - signal.signal(signal.SIGINT, original_handler) - if interrupted: - raise KeyboardInterrupt - - -def main(args: Args): - # Make sure external camera is specified by user -- we only use one external camera for the policy - assert ( - args.external_camera is not None and args.external_camera in ["left", "right"] - ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}" - - # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important. - env = RobotEnv(action_space="joint_velocity", gripper_action_space="position") - print("Created the droid env!") - - # Connect to the policy server - policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port) - - df = pd.DataFrame(columns=["success", "duration", "video_filename"]) - - while True: - instruction = input("Enter instruction: ") - - # Rollout parameters - actions_from_chunk_completed = 0 - pred_action_chunk = None - - # Prepare to save video of rollout - timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S") - video = [] - bar = tqdm.tqdm(range(args.max_timesteps)) - print("Running rollout... press Ctrl+C to stop early.") - for t_step in bar: - start_time = time.time() - try: - # Get the current observation - curr_obs = _extract_observation( - args, - env.get_observation(), - # Save the first observation to disk - save_to_disk=t_step == 0, - ) - - video.append(curr_obs[f"{args.external_camera}_image"]) - - # Send websocket request to policy server if it's time to predict a new chunk - if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon: - actions_from_chunk_completed = 0 - - # We resize images on the robot laptop to minimize the amount of data sent to the policy server - # and improve latency. 
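-                        # Descriptive note: the request payload sent to the policy server contains one
-                        # external RGB image, the wrist RGB image, the 7-D joint positions, the 1-D
-                        # gripper position, and the language prompt; all image keys refer to the left
-                        # frame of each stereo pair.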
- request_data = { - "observation/exterior_image_1_left": image_tools.resize_with_pad( - curr_obs[f"{args.external_camera}_image"], 224, 224 - ), - "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224), - "observation/joint_position": curr_obs["joint_position"], - "observation/gripper_position": curr_obs["gripper_position"], - "prompt": instruction, - } - - # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it - # Ctrl+C will be handled after the server call is complete - with prevent_keyboard_interrupt(): - # this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1) - pred_action_chunk = policy_client.infer(request_data)["actions"] - assert pred_action_chunk.shape == (10, 8) - - # Select current action to execute from chunk - action = pred_action_chunk[actions_from_chunk_completed] - actions_from_chunk_completed += 1 - - # Binarize gripper action - if action[-1].item() > 0.5: - # action[-1] = 1.0 - action = np.concatenate([action[:-1], np.ones((1,))]) - else: - # action[-1] = 0.0 - action = np.concatenate([action[:-1], np.zeros((1,))]) - - # clip all dimensions of action to [-1, 1] - action = np.clip(action, -1, 1) - - env.step(action) - - # Sleep to match DROID data collection frequency - elapsed_time = time.time() - start_time - if elapsed_time < 1 / DROID_CONTROL_FREQUENCY: - time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time) - except KeyboardInterrupt: - break - - video = np.stack(video) - save_filename = "video_" + timestamp - ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264") - - success: str | float | None = None - while not isinstance(success, float): - success = input( - "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec" - ) - if success == "y": - success = 1.0 - elif success == "n": - success = 0.0 - - success = float(success) / 100 - if not (0 <= success <= 1): - print(f"Success must be a number in [0, 100] but got: {success * 100}") - - df = df.append( - { - "success": success, - "duration": t_step, - "video_filename": save_filename, - }, - ignore_index=True, - ) - - if input("Do one more eval? (enter y or n) ").lower() != "y": - break - env.reset() - - os.makedirs("results", exist_ok=True) - timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y") - csv_filename = os.path.join("results", f"eval_{timestamp}.csv") - df.to_csv(csv_filename) - print(f"Results saved to {csv_filename}") - - -def _extract_observation(args: Args, obs_dict, *, save_to_disk=False): - image_observations = obs_dict["image"] - left_image, right_image, wrist_image = None, None, None - for key in image_observations: - # Note the "left" below refers to the left camera in the stereo pair. - # The model is only trained on left stereo cams, so we only feed those. 
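-        # Descriptive note (assumption about key naming): observation image keys are expected to
-        # embed the camera serial number together with a "left"/"right" suffix, which is why the
-        # substring checks below match the configured serial plus "left" to pick the left stereo frame.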
- if args.left_camera_id in key and "left" in key: - left_image = image_observations[key] - elif args.right_camera_id in key and "left" in key: - right_image = image_observations[key] - elif args.wrist_camera_id in key and "left" in key: - wrist_image = image_observations[key] - - # Drop the alpha dimension - left_image = left_image[..., :3] - right_image = right_image[..., :3] - wrist_image = wrist_image[..., :3] - - # Convert to RGB - left_image = left_image[..., ::-1] - right_image = right_image[..., ::-1] - wrist_image = wrist_image[..., ::-1] - - # In addition to image observations, also capture the proprioceptive state - robot_state = obs_dict["robot_state"] - cartesian_position = np.array(robot_state["cartesian_position"]) - joint_position = np.array(robot_state["joint_positions"]) - gripper_position = np.array([robot_state["gripper_position"]]) - - # Save the images to disk so that they can be viewed live while the robot is running - # Create one combined image to make live viewing easy - if save_to_disk: - combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1) - combined_image = Image.fromarray(combined_image) - combined_image.save("robot_camera_views.png") - - return { - "left_image": left_image, - "right_image": right_image, - "wrist_image": wrist_image, - "cartesian_position": cartesian_position, - "joint_position": joint_position, - "gripper_position": gripper_position, - } - - -if __name__ == "__main__": - args: Args = tyro.cli(Args) - main(args) diff --git a/capvector-pi05/examples/inference.ipynb b/capvector-pi05/examples/inference.ipynb deleted file mode 100644 index 0f2627f148de7448474270abd20bdfb4047aa082..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/inference.ipynb +++ /dev/null @@ -1,137 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import dataclasses\n", - "\n", - "import jax\n", - "\n", - "from openpi.models import model as _model\n", - "from openpi.policies import droid_policy\n", - "from openpi.policies import policy_config as _policy_config\n", - "from openpi.shared import download\n", - "from openpi.training import config as _config\n", - "from openpi.training import data_loader as _data_loader" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Policy inference\n", - "\n", - "The following example shows how to create a policy from a checkpoint and run inference on a dummy example." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = _config.get_config(\"pi0_fast_droid\")\n", - "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n", - "\n", - "# Create a trained policy.\n", - "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n", - "\n", - "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n", - "example = droid_policy.make_droid_example()\n", - "result = policy.infer(example)\n", - "\n", - "# Delete the policy to free up memory.\n", - "del policy\n", - "\n", - "print(\"Actions shape:\", result[\"actions\"].shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Working with a live model\n", - "\n", - "\n", - "The following example shows how to create a live model from a checkpoint and compute training loss. 
First, we are going to demonstrate how to do it with fake data.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = _config.get_config(\"pi0_aloha_sim\")\n", - "\n", - "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n", - "key = jax.random.key(0)\n", - "\n", - "# Create a model from the checkpoint.\n", - "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n", - "\n", - "# We can create fake observations and actions to test the model.\n", - "obs, act = config.model.fake_obs(), config.model.fake_act()\n", - "\n", - "# Sample actions from the model.\n", - "loss = model.compute_loss(key, obs, act)\n", - "print(\"Loss shape:\", loss.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we are going to create a data loader and use a real batch of training data to compute the loss." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Reduce the batch size to reduce memory usage.\n", - "config = dataclasses.replace(config, batch_size=2)\n", - "\n", - "# Load a single batch of data. This is the same data that will be used during training.\n", - "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n", - "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n", - "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n", - "obs, act = next(iter(loader))\n", - "\n", - "# Sample actions from the model.\n", - "loss = model.compute_loss(key, obs, act)\n", - "\n", - "# Delete the model to free up memory.\n", - "del model\n", - "\n", - "print(\"Loss shape:\", loss.shape)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/capvector-pi05/examples/libero/Dockerfile b/capvector-pi05/examples/libero/Dockerfile deleted file mode 100644 index 3e1ed413f0441f3ab556974910fc6a7e138dd1e6..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/Dockerfile +++ /dev/null @@ -1,59 +0,0 @@ -# Dockerfile for the LIBERO benchmark. - -# Build the container: -# docker build . -t libero -f examples/libero/Dockerfile - -# Run the container: -# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash - -FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 -COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ - -RUN apt-get update && \ - apt-get install -y \ - make \ - g++ \ - clang \ - libosmesa6-dev \ - libgl1-mesa-glx \ - libglew-dev \ - libglfw3-dev \ - libgles2-mesa-dev \ - libglib2.0-0 \ - libsm6 \ - libxrender1 \ - libxext6 - -WORKDIR /app - -# Copy from the cache instead of linking since it's a mounted volume -ENV UV_LINK_MODE=copy - -# Write the virtual environment outside of the project directory so it doesn't -# leak out of the container when we mount the application code. 
-ENV UV_PROJECT_ENVIRONMENT=/.venv - -# Copy the requirements files so we can install dependencies. -# The rest of the project is mounted as a volume, so we don't need to rebuild on changes. -# This strategy is best for development-style usage. -COPY ./examples/libero/requirements.txt /tmp/requirements.txt -COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt -COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml - -# Install python dependencies. -RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT -RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match -ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero - -# Create a default config file to avoid an input prompt from LIBERO's init script. -# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py -ENV LIBERO_CONFIG_PATH=/tmp/libero -RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml -benchmark_root: /app/third_party/libero/libero/libero -bddl_files: /app/third_party/libero/libero/libero/bddl_files -init_states: /app/third_party/libero/libero/libero/init_files -datasets: /app/third_party/libero/libero/datasets -assets: /app/third_party/libero/libero/libero/assets -EOF - -CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"] diff --git a/capvector-pi05/examples/libero/README.md b/capvector-pi05/examples/libero/README.md deleted file mode 100644 index 2e16f67368d09ae80e16ab67515907fcb034c1cc..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/README.md +++ /dev/null @@ -1,71 +0,0 @@ -# LIBERO Benchmark - -This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO - -Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command. - -This example requires git submodules to be initialized. Don't forget to run: - -```bash -git submodule update --init --recursive -``` - -## With Docker (recommended) - -```bash -# Grant access to the X11 server: -sudo xhost +local:docker - -# To run with the default checkpoint and task suite: -SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build - -# To run with glx for Mujoco instead (use this if you have egl errors): -MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build -``` - -You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`). 
-For example: - -```bash -# To load a custom checkpoint (located in the top-level openpi/ directory): -export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint" - -# To run the libero_10 task suite: -export CLIENT_ARGS="--args.task-suite-name libero_10" -``` - -## Without Docker (not recommended) - -Terminal window 1: - -```bash -# Create virtual environment -uv venv --python 3.8 examples/libero/.venv -source examples/libero/.venv/bin/activate -uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match -uv pip install -e packages/openpi-client -uv pip install -e third_party/libero -export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero - -# Run the simulation -python examples/libero/main.py - -# To run with glx for Mujoco instead (use this if you have egl errors): -MUJOCO_GL=glx python examples/libero/main.py -``` - -Terminal window 2: - -```bash -# Run the server -uv run scripts/serve_policy.py --env LIBERO -``` - -## Results - -If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This -checkpoint was trained in openpi with the `pi05_libero` config. - -| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average | -|-------|---------------|---------------|-------------|-----------|---------| -| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 diff --git a/capvector-pi05/examples/libero/compose.yml b/capvector-pi05/examples/libero/compose.yml deleted file mode 100644 index d1ef58e8c47874829baf2dba759d49fc67b6a8e3..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/compose.yml +++ /dev/null @@ -1,54 +0,0 @@ -# Run with: -# docker compose -f examples/libero/compose.yml up --build -services: - runtime: - image: libero - depends_on: - - openpi_server - build: - context: ../.. - dockerfile: examples/libero/Dockerfile - init: true - tty: true - network_mode: host - privileged: true - volumes: - - $PWD:/app - - ../../data:/data - - /tmp/.X11-unix:/tmp/.X11-unix:ro - environment: - - CLIENT_ARGS - - DISPLAY=$DISPLAY - - MUJOCO_GL=${MUJOCO_GL:-egl} - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] - - openpi_server: - image: openpi_server - build: - context: ../.. - dockerfile: scripts/docker/serve_policy.Dockerfile - init: true - tty: true - network_mode: host - volumes: - - $PWD:/app - - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets - environment: - - SERVER_ARGS - - OPENPI_DATA_HOME=/openpi_assets - - IS_DOCKER=true - - # Comment out this block if not running on a machine with GPUs. - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] diff --git a/capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py b/capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py deleted file mode 100644 index 54a390fb7826c767b30de83c0a2cb03cc15c4e55..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Minimal example script for converting a dataset to LeRobot format. - -We use the Libero dataset (stored in RLDS) for this example, but it can be easily -modified for any other data you have saved in a custom format. 
- -Usage: -uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data - -If you want to push your dataset to the Hugging Face Hub, you can use the following command: -uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub - -Note: to run the script, you need to install tensorflow_datasets: -`uv pip install tensorflow tensorflow_datasets` - -You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds -The resulting dataset will get saved to the $HF_LEROBOT_HOME directory. -Running this conversion script will take approximately 30 minutes. -""" - -import shutil - -from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -import tensorflow_datasets as tfds -import tyro - -REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub -RAW_DATASET_NAMES = [ - "libero_10_no_noops", - "libero_goal_no_noops", - "libero_object_no_noops", - "libero_spatial_no_noops", -] # For simplicity we will combine multiple Libero datasets into one training dataset - - -def main(data_dir: str, *, push_to_hub: bool = False): - # Clean up any existing dataset in the output directory - output_path = HF_LEROBOT_HOME / REPO_NAME - if output_path.exists(): - shutil.rmtree(output_path) - - # Create LeRobot dataset, define features to store - # OpenPi assumes that proprio is stored in `state` and actions in `action` - # LeRobot assumes that dtype of image data is `image` - dataset = LeRobotDataset.create( - repo_id=REPO_NAME, - robot_type="panda", - fps=10, - features={ - "image": { - "dtype": "image", - "shape": (256, 256, 3), - "names": ["height", "width", "channel"], - }, - "wrist_image": { - "dtype": "image", - "shape": (256, 256, 3), - "names": ["height", "width", "channel"], - }, - "state": { - "dtype": "float32", - "shape": (8,), - "names": ["state"], - }, - "actions": { - "dtype": "float32", - "shape": (7,), - "names": ["actions"], - }, - }, - image_writer_threads=10, - image_writer_processes=5, - ) - - # Loop over raw Libero datasets and write episodes to the LeRobot dataset - # You can modify this for your own data format - for raw_dataset_name in RAW_DATASET_NAMES: - raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train") - for episode in raw_dataset: - for step in episode["steps"].as_numpy_iterator(): - dataset.add_frame( - { - "image": step["observation"]["image"], - "wrist_image": step["observation"]["wrist_image"], - "state": step["observation"]["state"], - "actions": step["action"], - "task": step["language_instruction"].decode(), - } - ) - dataset.save_episode() - - # Optionally push to the Hugging Face Hub - if push_to_hub: - dataset.push_to_hub( - tags=["libero", "panda", "rlds"], - private=False, - push_videos=True, - license="apache-2.0", - ) - - -if __name__ == "__main__": - tyro.cli(main) diff --git a/capvector-pi05/examples/libero/main.py b/capvector-pi05/examples/libero/main.py deleted file mode 100644 index 2a1ab94db9aac8288708455d3404e03f08e5383f..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/main.py +++ /dev/null @@ -1,219 +0,0 @@ -import collections -import dataclasses -import logging -import math -import pathlib - -import imageio -from libero.libero import benchmark -from libero.libero import get_libero_path -from libero.libero.envs import OffScreenRenderEnv -import numpy as np -from openpi_client 
import image_tools -from openpi_client import websocket_client_policy as _websocket_client_policy -import tqdm -import tyro - -LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0] -LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data - - -@dataclasses.dataclass -class Args: - ################################################################################################################# - # Model server parameters - ################################################################################################################# - host: str = "0.0.0.0" - port: int = 8000 - resize_size: int = 224 - replan_steps: int = 5 - - ################################################################################################################# - # LIBERO environment-specific parameters - ################################################################################################################# - task_suite_name: str = ( - "libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90 - ) - num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize i n sim - num_trials_per_task: int = 50 # Number of rollouts per task - - ################################################################################################################# - # Utils - ################################################################################################################# - video_out_path: str = "data/libero/videos" # Path to save videos - - seed: int = 7 # Random Seed (for reproducibility) - - -def eval_libero(args: Args) -> None: - # Set random seed - np.random.seed(args.seed) - - # Initialize LIBERO task suite - benchmark_dict = benchmark.get_benchmark_dict() - task_suite = benchmark_dict[args.task_suite_name]() - num_tasks_in_suite = task_suite.n_tasks - logging.info(f"Task suite: {args.task_suite_name}") - - pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True) - - if args.task_suite_name == "libero_spatial": - max_steps = 220 # longest training demo has 193 steps - elif args.task_suite_name == "libero_object": - max_steps = 280 # longest training demo has 254 steps - elif args.task_suite_name == "libero_goal": - max_steps = 300 # longest training demo has 270 steps - elif args.task_suite_name == "libero_10": - max_steps = 520 # longest training demo has 505 steps - elif args.task_suite_name == "libero_90": - max_steps = 400 # longest training demo has 373 steps - else: - raise ValueError(f"Unknown task suite: {args.task_suite_name}") - - client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port) - - # Start evaluation - total_episodes, total_successes = 0, 0 - for task_id in tqdm.tqdm(range(num_tasks_in_suite)): - # Get task - task = task_suite.get_task(task_id) - - # Get default LIBERO initial states - initial_states = task_suite.get_task_init_states(task_id) - - # Initialize LIBERO environment and task description - env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed) - - # Start episodes - task_episodes, task_successes = 0, 0 - for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)): - logging.info(f"\nTask: {task_description}") - - # Reset environment - env.reset() - action_plan = collections.deque() - - # Set initial states - obs = env.set_init_state(initial_states[episode_idx]) - - # Setup - t = 0 - replay_images = [] - - logging.info(f"Starting episode {task_episodes+1}...") - while t < max_steps + args.num_steps_wait: - try: - # IMPORTANT: Do nothing for the first 
few timesteps because the simulator drops objects - # and we need to wait for them to fall - if t < args.num_steps_wait: - obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION) - t += 1 - continue - - # Get preprocessed image - # IMPORTANT: rotate 180 degrees to match train preprocessing - img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) - wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) - img = image_tools.convert_to_uint8( - image_tools.resize_with_pad(img, args.resize_size, args.resize_size) - ) - wrist_img = image_tools.convert_to_uint8( - image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size) - ) - - # Save preprocessed image for replay video - replay_images.append(img) - - if not action_plan: - # Finished executing previous action chunk -- compute new chunk - # Prepare observations dict - element = { - "observation/image": img, - "observation/wrist_image": wrist_img, - "observation/state": np.concatenate( - ( - obs["robot0_eef_pos"], - _quat2axisangle(obs["robot0_eef_quat"]), - obs["robot0_gripper_qpos"], - ) - ), - "prompt": str(task_description), - } - - # Query model to get action - action_chunk = client.infer(element)["actions"] - assert ( - len(action_chunk) >= args.replan_steps - ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps." - action_plan.extend(action_chunk[: args.replan_steps]) - - action = action_plan.popleft() - - # Execute action in environment - obs, reward, done, info = env.step(action.tolist()) - if done: - task_successes += 1 - total_successes += 1 - break - t += 1 - - except Exception as e: - logging.error(f"Caught exception: {e}") - break - - task_episodes += 1 - total_episodes += 1 - - # Save a replay video of the episode - suffix = "success" if done else "failure" - task_segment = task_description.replace(" ", "_") - imageio.mimwrite( - pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4", - [np.asarray(x) for x in replay_images], - fps=10, - ) - - # Log current results - logging.info(f"Success: {done}") - logging.info(f"# episodes completed so far: {total_episodes}") - logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)") - - # Log final results - logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}") - logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}") - - logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}") - logging.info(f"Total episodes: {total_episodes}") - - -def _get_libero_env(task, resolution, seed): - """Initializes and returns the LIBERO environment, along with the task description.""" - task_description = task.language - task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file - env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} - env = OffScreenRenderEnv(**env_args) - env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state - return env, task_description - - -def _quat2axisangle(quat): - """ - Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 - """ - # clip quaternion - if quat[3] > 1.0: - quat[3] = 1.0 - elif quat[3] < -1.0: - quat[3] = -1.0 - - den = np.sqrt(1.0 - quat[3] * quat[3]) - if 
math.isclose(den, 0.0): - # This is (close to) a zero degree rotation, immediately return - return np.zeros(3) - - return (quat[:3] * 2.0 * math.acos(quat[3])) / den - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - tyro.cli(eval_libero) diff --git a/capvector-pi05/examples/libero/requirements.in b/capvector-pi05/examples/libero/requirements.in deleted file mode 100644 index d9fd2275d739216c453e67fbe3c060ccec56cca4..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/requirements.in +++ /dev/null @@ -1,11 +0,0 @@ -imageio[ffmpeg] -numpy==1.22.4 -tqdm -tyro -PyYaml -opencv-python==4.6.0.66 -torch==1.11.0+cu113 -torchvision==0.12.0+cu113 -torchaudio==0.11.0+cu113 -robosuite==1.4.1 -matplotlib==3.5.3 diff --git a/capvector-pi05/examples/libero/requirements.txt b/capvector-pi05/examples/libero/requirements.txt deleted file mode 100644 index 9123401789ca3740077de1a56842db145a697781..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/libero/requirements.txt +++ /dev/null @@ -1,136 +0,0 @@ -# This file was autogenerated by uv via the following command: -# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match -absl-py==2.1.0 - # via mujoco -certifi==2024.12.14 - # via requests -charset-normalizer==3.4.0 - # via requests -cycler==0.12.1 - # via matplotlib -docstring-parser==0.16 - # via tyro -etils==1.3.0 - # via mujoco -eval-type-backport==0.2.0 - # via tyro -evdev==1.7.1 - # via pynput -fonttools==4.55.3 - # via matplotlib -glfw==1.12.0 - # via mujoco -idna==3.10 - # via requests -imageio==2.35.1 - # via -r examples/libero/requirements.in -imageio-ffmpeg==0.5.1 - # via imageio -importlib-metadata==8.5.0 - # via typeguard -importlib-resources==6.4.5 - # via etils -kiwisolver==1.4.7 - # via matplotlib -llvmlite==0.36.0 - # via numba -markdown-it-py==3.0.0 - # via rich -matplotlib==3.5.3 - # via -r examples/libero/requirements.in -mdurl==0.1.2 - # via markdown-it-py -mujoco==3.2.3 - # via robosuite -numba==0.53.1 - # via robosuite -numpy==1.22.4 - # via - # -r examples/libero/requirements.in - # imageio - # matplotlib - # mujoco - # numba - # opencv-python - # robosuite - # scipy - # torchvision -opencv-python==4.6.0.66 - # via - # -r examples/libero/requirements.in - # robosuite -packaging==24.2 - # via matplotlib -pillow==10.4.0 - # via - # imageio - # matplotlib - # robosuite - # torchvision -psutil==6.1.0 - # via imageio -pygments==2.18.0 - # via rich -pynput==1.7.7 - # via robosuite -pyopengl==3.1.7 - # via mujoco -pyparsing==3.1.4 - # via matplotlib -python-dateutil==2.9.0.post0 - # via matplotlib -python-xlib==0.33 - # via pynput -pyyaml==6.0.2 - # via -r examples/libero/requirements.in -requests==2.32.3 - # via torchvision -rich==13.9.4 - # via tyro -robosuite==1.4.1 - # via -r examples/libero/requirements.in -scipy==1.10.1 - # via robosuite -setuptools==75.3.0 - # via - # imageio-ffmpeg - # numba -shtab==1.7.1 - # via tyro -six==1.17.0 - # via - # pynput - # python-dateutil - # python-xlib -termcolor==2.4.0 - # via robosuite -torch==1.11.0+cu113 - # via - # -r examples/libero/requirements.in - # torchaudio - # torchvision -torchaudio==0.11.0+cu113 - # via -r examples/libero/requirements.in -torchvision==0.12.0+cu113 - # via -r examples/libero/requirements.in -tqdm==4.67.1 - # via -r examples/libero/requirements.in -typeguard==4.4.0 - # via tyro -typing-extensions==4.12.2 - # via - # etils - # rich - # torch - # torchvision - # typeguard 
- # tyro -tyro==0.9.2 - # via -r examples/libero/requirements.in -urllib3==2.2.3 - # via requests -zipp==3.20.2 - # via - # etils - # importlib-metadata - # importlib-resources diff --git a/capvector-pi05/examples/policy_records.ipynb b/capvector-pi05/examples/policy_records.ipynb deleted file mode 100644 index f83575b47ff30a697dfbf8e76015f90dec898538..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/policy_records.ipynb +++ /dev/null @@ -1,134 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pathlib\n", - "\n", - "import numpy as np\n", - "\n", - "record_path = pathlib.Path(\"../policy_records\")\n", - "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n", - "\n", - "records = []\n", - "for i in range(num_steps):\n", - " record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n", - " records.append(record)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"length of records\", len(records))\n", - "print(\"keys in records\", records[0].keys())\n", - "\n", - "for k in records[0]:\n", - " print(f\"{k} shape: {records[0][k].shape}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from PIL import Image\n", - "\n", - "\n", - "def get_image(step: int, idx: int = 0):\n", - " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n", - " return img[idx].transpose(1, 2, 0)\n", - "\n", - "\n", - "def show_image(step: int, idx_lst: list[int]):\n", - " imgs = [get_image(step, idx) for idx in idx_lst]\n", - " return Image.fromarray(np.hstack(imgs))\n", - "\n", - "\n", - "for i in range(2):\n", - " display(show_image(i, [0]))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "\n", - "def get_axis(name, axis):\n", - " return np.array([record[name][axis] for record in records])\n", - "\n", - "\n", - "# qpos is [..., 14] of type float:\n", - "# 0-5: left arm joint angles\n", - "# 6: left arm gripper\n", - "# 7-12: right arm joint angles\n", - "# 13: right arm gripper\n", - "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n", - "\n", - "\n", - "def make_data():\n", - " cur_dim = 0\n", - " in_data = {}\n", - " out_data = {}\n", - " for name, dim_size in names:\n", - " for i in range(dim_size):\n", - " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n", - " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n", - " cur_dim += 1\n", - " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n", - "\n", - "\n", - "in_data, out_data = make_data()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for name in in_data.columns:\n", - " data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n", - " data.plot()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - 
"version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/capvector-pi05/examples/simple_client/Dockerfile b/capvector-pi05/examples/simple_client/Dockerfile deleted file mode 100644 index 095712073f98d5dd9b639d6d365698b506d15827..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/simple_client/Dockerfile +++ /dev/null @@ -1,32 +0,0 @@ -# Dockerfile for the simple client. - -# Build the container: -# docker build . -t simple_client -f examples/simple_client/Dockerfile - -# Run the container: -# docker run --rm -it --network=host -v .:/app simple_client /bin/bash - -FROM python:3.7-slim -COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ - -WORKDIR /app - -# Copy from the cache instead of linking since it's a mounted volume -ENV UV_LINK_MODE=copy - -# Write the virtual environment outside of the project directory so it doesn't -# leak out of the container when we mount the application code. -ENV UV_PROJECT_ENVIRONMENT=/.venv - -# Copy the requirements files so we can install dependencies. -# The rest of the project is mounted as a volume, so we don't need to rebuild on changes. -# This strategy is best for development-style usage. -COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt -COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml - -# Install python dependencies. -RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT -RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml -ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src - -CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS" diff --git a/capvector-pi05/examples/simple_client/README.md b/capvector-pi05/examples/simple_client/README.md deleted file mode 100644 index ea2fe5050c25cbfccea3f1c5c99c8b6ef0961caa..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/simple_client/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Simple Client - -A minimal client that sends observations to the server and prints the inference rate. - -You can specify which runtime environment to use using the `--env` flag. You can see the available options by running: - -```bash -uv run examples/simple_client/main.py --help -``` - -## With Docker - -```bash -export SERVER_ARGS="--env ALOHA_SIM" -docker compose -f examples/simple_client/compose.yml up --build -``` - -## Without Docker - -Terminal window 1: - -```bash -uv run examples/simple_client/main.py --env DROID -``` - -Terminal window 2: - -```bash -uv run scripts/serve_policy.py --env DROID -``` diff --git a/capvector-pi05/examples/simple_client/compose.yml b/capvector-pi05/examples/simple_client/compose.yml deleted file mode 100644 index 109821bb445ff61c0ccb695fb5c43a1eb4220005..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/simple_client/compose.yml +++ /dev/null @@ -1,42 +0,0 @@ -# Run with: -# docker compose -f examples/simple_client/compose.yml up --build -services: - runtime: - image: simple_client - depends_on: - - openpi_server - build: - context: ../.. - dockerfile: examples/simple_client/Dockerfile - init: true - tty: true - network_mode: host - volumes: - - $PWD:/app - environment: - - SERVER_ARGS - - openpi_server: - image: openpi_server - build: - context: ../.. 
- dockerfile: scripts/docker/serve_policy.Dockerfile - init: true - tty: true - network_mode: host - volumes: - - $PWD:/app - - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets - environment: - - SERVER_ARGS - - OPENPI_DATA_HOME=/openpi_assets - - IS_DOCKER=true - - # Comment out this block if not running on a machine with GPUs. - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] diff --git a/capvector-pi05/examples/simple_client/main.py b/capvector-pi05/examples/simple_client/main.py deleted file mode 100644 index 3907706b164e55d04e948cae6620d940c69ca7fe..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/simple_client/main.py +++ /dev/null @@ -1,187 +0,0 @@ -import dataclasses -import enum -import logging -import pathlib -import time - -import numpy as np -from openpi_client import websocket_client_policy as _websocket_client_policy -import polars as pl -import rich -import tqdm -import tyro - -logger = logging.getLogger(__name__) - - -class EnvMode(enum.Enum): - """Supported environments.""" - - ALOHA = "aloha" - ALOHA_SIM = "aloha_sim" - DROID = "droid" - LIBERO = "libero" - - -@dataclasses.dataclass -class Args: - """Command line arguments.""" - - # Host and port to connect to the server. - host: str = "0.0.0.0" - # Port to connect to the server. If None, the server will use the default port. - port: int | None = 8000 - # API key to use for the server. - api_key: str | None = None - # Number of steps to run the policy for. - num_steps: int = 20 - # Path to save the timings to a parquet file. (e.g., timing.parquet) - timing_file: pathlib.Path | None = None - # Environment to run the policy in. - env: EnvMode = EnvMode.ALOHA_SIM - - -class TimingRecorder: - """Records timing measurements for different keys.""" - - def __init__(self) -> None: - self._timings: dict[str, list[float]] = {} - - def record(self, key: str, time_ms: float) -> None: - """Record a timing measurement for the given key.""" - if key not in self._timings: - self._timings[key] = [] - self._timings[key].append(time_ms) - - def get_stats(self, key: str) -> dict[str, float]: - """Get statistics for the given key.""" - times = self._timings[key] - return { - "mean": float(np.mean(times)), - "std": float(np.std(times)), - "p25": float(np.quantile(times, 0.25)), - "p50": float(np.quantile(times, 0.50)), - "p75": float(np.quantile(times, 0.75)), - "p90": float(np.quantile(times, 0.90)), - "p95": float(np.quantile(times, 0.95)), - "p99": float(np.quantile(times, 0.99)), - } - - def print_all_stats(self) -> None: - """Print statistics for all keys in a concise format.""" - - table = rich.table.Table( - title="[bold blue]Timing Statistics[/bold blue]", - show_header=True, - header_style="bold white", - border_style="blue", - title_justify="center", - ) - - # Add metric column with custom styling - table.add_column("Metric", style="cyan", justify="left", no_wrap=True) - - # Add statistical columns with consistent styling - stat_columns = [ - ("Mean", "yellow", "mean"), - ("Std", "yellow", "std"), - ("P25", "magenta", "p25"), - ("P50", "magenta", "p50"), - ("P75", "magenta", "p75"), - ("P90", "magenta", "p90"), - ("P95", "magenta", "p95"), - ("P99", "magenta", "p99"), - ] - - for name, style, _ in stat_columns: - table.add_column(name, justify="right", style=style, no_wrap=True) - - # Add rows for each metric with formatted values - for key in sorted(self._timings.keys()): - stats = self.get_stats(key) - values = [f"{stats[key]:.1f}" for _, _, key in 
stat_columns] - table.add_row(key, *values) - - # Print with custom console settings - console = rich.console.Console(width=None, highlight=True) - console.print(table) - - def write_parquet(self, path: pathlib.Path) -> None: - """Save the timings to a parquet file.""" - logger.info(f"Writing timings to {path}") - frame = pl.DataFrame(self._timings) - path.parent.mkdir(parents=True, exist_ok=True) - frame.write_parquet(path) - - -def main(args: Args) -> None: - obs_fn = { - EnvMode.ALOHA: _random_observation_aloha, - EnvMode.ALOHA_SIM: _random_observation_aloha, - EnvMode.DROID: _random_observation_droid, - EnvMode.LIBERO: _random_observation_libero, - }[args.env] - - policy = _websocket_client_policy.WebsocketClientPolicy( - host=args.host, - port=args.port, - api_key=args.api_key, - ) - logger.info(f"Server metadata: {policy.get_server_metadata()}") - - # Send a few observations to make sure the model is loaded. - for _ in range(2): - policy.infer(obs_fn()) - - timing_recorder = TimingRecorder() - - for _ in tqdm.trange(args.num_steps, desc="Running policy"): - inference_start = time.time() - action = policy.infer(obs_fn()) - timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start)) - for key, value in action.get("server_timing", {}).items(): - timing_recorder.record(f"server_{key}", value) - for key, value in action.get("policy_timing", {}).items(): - timing_recorder.record(f"policy_{key}", value) - - timing_recorder.print_all_stats() - - if args.timing_file is not None: - timing_recorder.write_parquet(args.timing_file) - - -def _random_observation_aloha() -> dict: - return { - "state": np.ones((14,)), - "images": { - "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - }, - "prompt": "do something", - } - - -def _random_observation_droid() -> dict: - return { - "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "observation/joint_position": np.random.rand(7), - "observation/gripper_position": np.random.rand(1), - "prompt": "do something", - } - - -def _random_observation_libero() -> dict: - return { - "observation/state": np.random.rand(8), - "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "prompt": "do something", - } - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - main(tyro.cli(Args)) diff --git a/capvector-pi05/examples/simple_client/requirements.in b/capvector-pi05/examples/simple_client/requirements.in deleted file mode 100644 index 17ef4aef112d274624eba0503e00cc4aec44f7a6..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/simple_client/requirements.in +++ /dev/null @@ -1,5 +0,0 @@ -numpy>=1.22.4,<2.0.0 -rich -tqdm -tyro -polars \ No newline at end of file diff --git a/capvector-pi05/examples/simple_client/requirements.txt b/capvector-pi05/examples/simple_client/requirements.txt deleted file mode 100644 index 416d9cd72e32f4c35c93dd0ac8e5d2144dc52513..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/simple_client/requirements.txt +++ /dev/null @@ -1,30 +0,0 @@ -# This file was 
autogenerated by uv via the following command: -# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9 -docstring-parser==0.16 - # via tyro -markdown-it-py==3.0.0 - # via rich -mdurl==0.1.2 - # via markdown-it-py -numpy==1.26.4 - # via -r examples/simple_client/requirements.in -polars==1.30.0 - # via -r examples/simple_client/requirements.in -pygments==2.19.1 - # via rich -rich==14.0.0 - # via - # -r examples/simple_client/requirements.in - # tyro -shtab==1.7.2 - # via tyro -tqdm==4.67.1 - # via -r examples/simple_client/requirements.in -typeguard==4.4.2 - # via tyro -typing-extensions==4.13.2 - # via - # typeguard - # tyro -tyro==0.9.22 - # via -r examples/simple_client/requirements.in diff --git a/capvector-pi05/examples/ur5/README.md b/capvector-pi05/examples/ur5/README.md deleted file mode 100644 index e90ca6c3bc8d1135e92ae55651639305bc0e209f..0000000000000000000000000000000000000000 --- a/capvector-pi05/examples/ur5/README.md +++ /dev/null @@ -1,142 +0,0 @@ -# UR5 Example - -Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets. - -First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding files in `src/openpi/policies/libero_policy.py` for comments explaining each line. - -```python - -@dataclasses.dataclass(frozen=True) -class UR5Inputs(transforms.DataTransformFn): - - model_type: _model.ModelType = _model.ModelType.PI0 - - def __call__(self, data: dict) -> dict: - # First, concatenate the joints and gripper into the state vector. - state = np.concatenate([data["joints"], data["gripper"]]) - - # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically - # stores as float32 (C,H,W), gets skipped for policy inference. - base_image = _parse_image(data["base_rgb"]) - wrist_image = _parse_image(data["wrist_rgb"]) - - # Create inputs dict. - inputs = { - "state": state, - "image": { - "base_0_rgb": base_image, - "left_wrist_0_rgb": wrist_image, - # Since there is no right wrist, replace with zeros - "right_wrist_0_rgb": np.zeros_like(base_image), - }, - "image_mask": { - "base_0_rgb": np.True_, - "left_wrist_0_rgb": np.True_, - # Since the "slot" for the right wrist is not used, this mask is set - # to False - "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_, - }, - } - - if "actions" in data: - inputs["actions"] = data["actions"] - - # Pass the prompt (aka language instruction) to the model. - if "prompt" in data: - inputs["prompt"] = data["prompt"] - - return inputs - - -@dataclasses.dataclass(frozen=True) -class UR5Outputs(transforms.DataTransformFn): - - def __call__(self, data: dict) -> dict: - # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims - return {"actions": np.asarray(data["actions"][:, :7])} - -``` - -Next, we will define the `UR5DataConfig` class, which defines how to process raw UR5 data from LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py). 
- -```python - -@dataclasses.dataclass(frozen=True) -class LeRobotUR5DataConfig(DataConfigFactory): - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming needed here. - repack_transform = _transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "base_rgb": "image", - "wrist_rgb": "wrist_image", - "joints": "joints", - "gripper": "gripper", - "prompt": "prompt", - } - ) - ] - ) - - # These transforms are the ones we wrote earlier. - data_transforms = _transforms.Group( - inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)], - outputs=[UR5Outputs()], - ) - - # Convert absolute actions to delta actions. - # By convention, we do not convert the gripper action (7th dimension). - delta_action_mask = _transforms.make_bool_mask(6, -1) - data_transforms = data_transforms.push( - inputs=[_transforms.DeltaActions(delta_action_mask)], - outputs=[_transforms.AbsoluteActions(delta_action_mask)], - ) - - # Model transforms include things like tokenizing the prompt and action targets - # You do not need to change anything here for your own dataset. - model_transforms = ModelTransformFactory()(model_config) - - # We return all data transforms for training and inference. No need to change anything here. - return dataclasses.replace( - self.create_base_config(assets_dirs), - repack_transforms=repack_transform, - data_transforms=data_transforms, - model_transforms=model_transforms, - ) - -``` - -Finally, we define the TrainConfig for our UR5 dataset. Here, we define a config for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning. - -```python -TrainConfig( - name="pi0_ur5", - model=pi0.Pi0Config(), - data=LeRobotUR5DataConfig( - repo_id="your_username/ur5_dataset", - # This config lets us reload the UR5 normalization stats from the base model checkpoint. - # Reloading normalization stats can help transfer pre-trained models to new environments. - # See the [norm_stats.md](../docs/norm_stats.md) file for more details. - assets=AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", - asset_id="ur5e", - ), - base_config=DataConfig( - # This flag determines whether we load the prompt (i.e. the task instruction) from the - # ``task`` field in the LeRobot dataset. The recommended setting is True. - prompt_from_task=True, - ), - ), - # Load the pi0 base model checkpoint. 
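-    # The pretrained parameters loaded here are merged into the freshly initialized model before training starts.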
- weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), - num_train_steps=30_000, -) -``` - - - - - diff --git a/capvector-pi05/packages/openpi-client/pyproject.toml b/capvector-pi05/packages/openpi-client/pyproject.toml deleted file mode 100644 index 123c066e6e79a6d1bbc8e385ab99422d58e16acc..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/pyproject.toml +++ /dev/null @@ -1,23 +0,0 @@ -[project] -name = "openpi-client" -version = "0.1.0" -requires-python = ">=3.7" -dependencies = [ - "dm-tree>=0.1.8", - "msgpack>=1.0.5", - "numpy>=1.22.4,<2.0.0", - "pillow>=9.0.0", - "tree>=0.2.4", - "websockets>=11.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.uv] -dev-dependencies = ["pytest>=8.3.4"] - -[tool.ruff] -line-length = 120 -target-version = "py37" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py b/capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py deleted file mode 100644 index 3f5c4a7d6e309ba9807642ee936d82cbc458017e..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py b/capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py deleted file mode 100644 index 9445a66815e15ee32ceb033d5a481b58053783fb..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Dict - -import numpy as np -import tree -from typing_extensions import override - -from openpi_client import base_policy as _base_policy - - -class ActionChunkBroker(_base_policy.BasePolicy): - """Wraps a policy to return action chunks one-at-a-time. - - Assumes that the first dimension of all action fields is the chunk size. - - A new inference call to the inner policy is only made when the current - list of chunks is exhausted. - """ - - def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int): - self._policy = policy - self._action_horizon = action_horizon - self._cur_step: int = 0 - - self._last_results: Dict[str, np.ndarray] | None = None - - @override - def infer(self, obs: Dict) -> Dict: # noqa: UP006 - if self._last_results is None: - self._last_results = self._policy.infer(obs) - self._cur_step = 0 - - def slicer(x): - if isinstance(x, np.ndarray): - return x[self._cur_step, ...] 
- else: - return x - - results = tree.map_structure(slicer, self._last_results) - self._cur_step += 1 - - if self._cur_step >= self._action_horizon: - self._last_results = None - - return results - - @override - def reset(self) -> None: - self._policy.reset() - self._last_results = None - self._cur_step = 0 diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py b/capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py deleted file mode 100644 index 1b14963fe90508b804b04c6480d82d5b3e2b5ca3..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py +++ /dev/null @@ -1,12 +0,0 @@ -import abc -from typing import Dict - - -class BasePolicy(abc.ABC): - @abc.abstractmethod - def infer(self, obs: Dict) -> Dict: - """Infer actions from observations.""" - - def reset(self) -> None: - """Reset the policy to its initial state.""" - pass diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py b/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py deleted file mode 100644 index 421532e216df6345902760be6cf1e22fc3c167fa..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py +++ /dev/null @@ -1,78 +0,0 @@ -import numpy as np -from PIL import Image - - -def convert_to_uint8(img: np.ndarray) -> np.ndarray: - """Converts an image to uint8 if it is a float image. - - This is important for reducing the size of the image when sending it over the network. - """ - if np.issubdtype(img.dtype, np.floating): - img = (255 * img).astype(np.uint8) - return img - - -def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR, return_mask=False) -> np.ndarray: - """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height. - - Args: - images: A batch of images in [..., height, width, channel] format. - height: The target height of the image. - width: The target width of the image. - method: The interpolation method to use. Default is bilinear. - - Returns: - The resized images in [..., height, width, channel]. - """ - # If the images are already the correct size, return them as is. - if images.shape[-3:-1] == (height, width): - if return_mask: - img_padding_mask = np.ones((*images.shape[:-3], height, width), dtype=bool) - return images, img_padding_mask - return images - - original_shape = images.shape - - images = images.reshape(-1, *original_shape[-3:]) - - resized_results = [ - _resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images - ] - resized_images, img_padding_mask = zip(*resized_results) - resized_images = np.stack(resized_images) - img_padding_mask = np.stack(img_padding_mask) - - if return_mask: - return ( - resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:]), - img_padding_mask.reshape(*original_shape[:-3], *img_padding_mask.shape[-2:]), - ) - else: - return resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:]) - - -def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image: - """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and - width without distortion by padding with zeros. - - Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c]. 
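-
-    Returns the padded image together with a boolean mask that is True over the resized image region and False over the padding.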
- """ - cur_width, cur_height = image.size - if cur_width == width and cur_height == height: - return image # No need to resize if the image is already the correct size. - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - resized_image = image.resize((resized_width, resized_height), resample=method) - - zero_image = Image.new(resized_image.mode, (width, height), 0) - pad_height = max(0, int((height - resized_height) / 2)) - pad_width = max(0, int((width - resized_width) / 2)) - zero_image.paste(resized_image, (pad_width, pad_height)) - assert zero_image.size == (width, height) - - img_padding_mask = np.zeros((height, width), dtype=bool) - img_padding_mask[pad_height:pad_height+resized_height, pad_width:pad_width+resized_width] = True - - return zero_image, img_padding_mask diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py b/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py deleted file mode 100644 index 1c8a2a26c04254f6246da50a2254f3c3c0c03c96..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np - -import openpi_client.image_tools as image_tools - - -def test_resize_with_pad_shapes(): - # Test case 1: Resize image with larger dimensions - images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels) - height = 20 - width = 20 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (2, height, width, 3) - assert np.all(resized_images == 0) - - # Test case 2: Resize image with smaller dimensions - images = np.zeros((3, 30, 30, 3), dtype=np.uint8) - height = 15 - width = 15 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (3, height, width, 3) - assert np.all(resized_images == 0) - - # Test case 3: Resize image with the same dimensions - images = np.zeros((1, 50, 50, 3), dtype=np.uint8) - height = 50 - width = 50 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (1, height, width, 3) - assert np.all(resized_images == 0) - - # Test case 3: Resize image with odd-numbered padding - images = np.zeros((1, 256, 320, 3), dtype=np.uint8) - height = 60 - width = 80 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (1, height, width, 3) - assert np.all(resized_images == 0) diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py b/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py deleted file mode 100644 index 70e353a9762de8ea45988354ea5d044fc03a52b4..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Adds NumPy array support to msgpack. 
- -msgpack is good for (de)serializing data over a network for multiple reasons: -- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution) -- msgpack is widely used and has good cross-language support -- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed - languages like Python and JavaScript -- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster - than pickle for serializing large arrays using the below strategy - -The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is -that it falls back to pickle for object arrays. -""" - -import functools - -import msgpack -import numpy as np - - -def pack_array(obj): - if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"): - raise ValueError(f"Unsupported dtype: {obj.dtype}") - - if isinstance(obj, np.ndarray): - return { - b"__ndarray__": True, - b"data": obj.tobytes(), - b"dtype": obj.dtype.str, - b"shape": obj.shape, - } - - if isinstance(obj, np.generic): - return { - b"__npgeneric__": True, - b"data": obj.item(), - b"dtype": obj.dtype.str, - } - - return obj - - -def unpack_array(obj): - if b"__ndarray__" in obj: - return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) - - if b"__npgeneric__" in obj: - return np.dtype(obj[b"dtype"]).type(obj[b"data"]) - - return obj - - -Packer = functools.partial(msgpack.Packer, default=pack_array) -packb = functools.partial(msgpack.packb, default=pack_array) - -Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) -unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py b/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py deleted file mode 100644 index d0d0b027c3ba77269151c2274226a6baadb410a4..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -import pytest -import tree - -from openpi_client import msgpack_numpy - - -def _check(expected, actual): - if isinstance(expected, np.ndarray): - assert expected.shape == actual.shape - assert expected.dtype == actual.dtype - assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") - else: - assert expected == actual - - -@pytest.mark.parametrize( - "data", - [ - 1, # int - 1.0, # float - "hello", # string - np.bool_(True), # boolean scalar - np.array([1, 2, 3])[0], # int scalar - np.str_("asdf"), # string scalar - [1, 2, 3], # list - {"key": "value"}, # dict - {"key": [1, 2, 3]}, # nested dict - np.array(1.0), # 0D array - np.array([1, 2, 3], dtype=np.int32), # 1D integer array - np.array(["asdf", "qwer"]), # string array - np.array([True, False]), # boolean array - np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array - np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array - np.array([np.nan, np.inf, -np.inf]), # special float values - {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays - [np.array([1, 2]), np.array([3, 4])], # list of arrays - np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros - np.ones((2, 3), dtype=np.float64), # 2D ones with double precision - ], -) -def 
test_pack_unpack(data): - packed = msgpack_numpy.packb(data) - unpacked = msgpack_numpy.unpackb(packed) - tree.map_structure(_check, data, unpacked) diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py deleted file mode 100644 index d09d57ddf0e670a7630b7bff95175984c3f9212e..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py +++ /dev/null @@ -1,17 +0,0 @@ -import abc - - -class Agent(abc.ABC): - """An Agent is the thing with agency, i.e. the entity that makes decisions. - - Agents receive observations about the state of the world, and return actions - to take in response. - """ - - @abc.abstractmethod - def get_action(self, observation: dict) -> dict: - """Query the agent for the next action.""" - - @abc.abstractmethod - def reset(self) -> None: - """Reset the agent to its initial state.""" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py deleted file mode 100644 index 2fff4f87f7072aa055dad04bb11c0524e385eaf4..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing_extensions import override - -from openpi_client import base_policy as _base_policy -from openpi_client.runtime import agent as _agent - - -class PolicyAgent(_agent.Agent): - """An agent that uses a policy to determine actions.""" - - def __init__(self, policy: _base_policy.BasePolicy) -> None: - self._policy = policy - - @override - def get_action(self, observation: dict) -> dict: - return self._policy.infer(observation) - - def reset(self) -> None: - self._policy.reset() diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py deleted file mode 100644 index 4b29f594f247700981fa87ff46a4500f060be052..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py +++ /dev/null @@ -1,32 +0,0 @@ -import abc - - -class Environment(abc.ABC): - """An Environment represents the robot and the environment it inhabits. - - The primary contract of environments is that they can be queried for observations - about their state, and have actions applied to them to change that state. - """ - - @abc.abstractmethod - def reset(self) -> None: - """Reset the environment to its initial state. - - This will be called once before starting each episode. - """ - - @abc.abstractmethod - def is_episode_complete(self) -> bool: - """Allow the environment to signal that the episode is complete. - - This will be called after each step. It should return `True` if the episode is - complete (either successfully or unsuccessfully), and `False` otherwise. 
- """ - - @abc.abstractmethod - def get_observation(self) -> dict: - """Query the environment for the current state.""" - - @abc.abstractmethod - def apply_action(self, action: dict) -> None: - """Take an action in the environment.""" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py deleted file mode 100644 index d480c2ebb01fc559a128c5e338a8be74ff8e55d3..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py +++ /dev/null @@ -1,92 +0,0 @@ -import logging -import threading -import time - -from openpi_client.runtime import agent as _agent -from openpi_client.runtime import environment as _environment -from openpi_client.runtime import subscriber as _subscriber - - -class Runtime: - """The core module orchestrating interactions between key components of the system.""" - - def __init__( - self, - environment: _environment.Environment, - agent: _agent.Agent, - subscribers: list[_subscriber.Subscriber], - max_hz: float = 0, - num_episodes: int = 1, - max_episode_steps: int = 0, - ) -> None: - self._environment = environment - self._agent = agent - self._subscribers = subscribers - self._max_hz = max_hz - self._num_episodes = num_episodes - self._max_episode_steps = max_episode_steps - - self._in_episode = False - self._episode_steps = 0 - - def run(self) -> None: - """Runs the runtime loop continuously until stop() is called or the environment is done.""" - for _ in range(self._num_episodes): - self._run_episode() - - # Final reset, this is important for real environments to move the robot to its home position. - self._environment.reset() - - def run_in_new_thread(self) -> threading.Thread: - """Runs the runtime loop in a new thread.""" - thread = threading.Thread(target=self.run) - thread.start() - return thread - - def mark_episode_complete(self) -> None: - """Marks the end of an episode.""" - self._in_episode = False - - def _run_episode(self) -> None: - """Runs a single episode.""" - logging.info("Starting episode...") - self._environment.reset() - self._agent.reset() - for subscriber in self._subscribers: - subscriber.on_episode_start() - - self._in_episode = True - self._episode_steps = 0 - step_time = 1 / self._max_hz if self._max_hz > 0 else 0 - last_step_time = time.time() - - while self._in_episode: - self._step() - self._episode_steps += 1 - - # Sleep to maintain the desired frame rate - now = time.time() - dt = now - last_step_time - if dt < step_time: - time.sleep(step_time - dt) - last_step_time = time.time() - else: - last_step_time = now - - logging.info("Episode completed.") - for subscriber in self._subscribers: - subscriber.on_episode_end() - - def _step(self) -> None: - """A single step of the runtime loop.""" - observation = self._environment.get_observation() - action = self._agent.get_action(observation) - self._environment.apply_action(action) - - for subscriber in self._subscribers: - subscriber.on_step(observation, action) - - if self._environment.is_episode_complete() or ( - self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps - ): - self.mark_episode_complete() diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py deleted file mode 100644 index e11b583aa2c4c962df7ed7907f5070ef30b97ef5..0000000000000000000000000000000000000000 --- 
a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py +++ /dev/null @@ -1,20 +0,0 @@ -import abc - - -class Subscriber(abc.ABC): - """Subscribes to events in the runtime. - - Subscribers can be used to save data, visualize, etc. - """ - - @abc.abstractmethod - def on_episode_start(self) -> None: - """Called when an episode starts.""" - - @abc.abstractmethod - def on_step(self, observation: dict, action: dict) -> None: - """Append a step to the episode.""" - - @abc.abstractmethod - def on_episode_end(self) -> None: - """Called when an episode ends.""" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py b/capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py deleted file mode 100644 index 6cd20760b0fb1ef93622626e17fad2f63146dbe7..0000000000000000000000000000000000000000 --- a/capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -import time -from typing import Dict, Optional, Tuple - -from typing_extensions import override -import websockets.sync.client - -from openpi_client import base_policy as _base_policy -from openpi_client import msgpack_numpy - - -class WebsocketClientPolicy(_base_policy.BasePolicy): - """Implements the Policy interface by communicating with a server over websocket. - - See WebsocketPolicyServer for a corresponding server implementation. - """ - - def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None: - self._uri = f"ws://{host}" - if port is not None: - self._uri += f":{port}" - self._packer = msgpack_numpy.Packer() - self._api_key = api_key - self._ws, self._server_metadata = self._wait_for_server() - - def get_server_metadata(self) -> Dict: - return self._server_metadata - - def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]: - logging.info(f"Waiting for server at {self._uri}...") - while True: - try: - headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None - conn = websockets.sync.client.connect( - self._uri, compression=None, max_size=None, additional_headers=headers - ) - metadata = msgpack_numpy.unpackb(conn.recv()) - return conn, metadata - except ConnectionRefusedError: - logging.info("Still waiting for server...") - time.sleep(5) - - @override - def infer(self, obs: Dict) -> Dict: # noqa: UP006 - data = self._packer.pack(obs) - self._ws.send(data) - response = self._ws.recv() - if isinstance(response, str): - # we're expecting bytes; if the server sends a string, it's an error. 
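-            # Surface the server-side error text to the caller rather than attempting to unpack it as a msgpack payload.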
- raise RuntimeError(f"Error in inference server:\n{response}") - return msgpack_numpy.unpackb(response) - - @override - def reset(self) -> None: - pass diff --git a/capvector-pi05/pyproject.toml b/capvector-pi05/pyproject.toml deleted file mode 100644 index a69ac282556588a89e875be2d7149313dbb14306..0000000000000000000000000000000000000000 --- a/capvector-pi05/pyproject.toml +++ /dev/null @@ -1,142 +0,0 @@ -[project] -name = "openpi" -version = "0.1.0" -description = "Physical Intelligence open source repo" -readme = "README.md" -requires-python = ">=3.11" -license = { file = "LICENSE" } -dependencies = [ - "augmax>=0.3.4", - "dm-tree>=0.1.8", - "einops>=0.8.0", - "equinox>=0.11.8", - "flatbuffers>=24.3.25", - "flax==0.10.2", - "fsspec[gcs]>=2024.6.0", - "gym-aloha>=0.1.1", - "imageio>=2.36.1", - "jax[cuda12]==0.5.3", - "jaxtyping==0.2.36", - "lerobot", - "ml_collections==1.0.0", - "numpy>=1.22.4,<2.0.0", - "numpydantic>=1.6.6", - "opencv-python>=4.10.0.84", - "openpi-client", - "orbax-checkpoint==0.11.13", - "pillow>=11.0.0", - "sentencepiece>=0.2.0", - "torch==2.7.1", - "tqdm-loggable>=0.2", - "typing-extensions>=4.12.2", - "tyro>=0.9.5", - "wandb>=0.19.1", - "filelock>=3.16.1", - "beartype==0.19.0", - "treescope>=0.1.7", - "transformers==4.53.2", - "rich>=14.0.0", - "polars>=1.30.0", - "gradio==5.17.1", - "viser==0.2.23", - "hydra-core", - "onnxruntime", - "safetensors", -] - - -[project.urls] -Repository = "https://github.com/Physical-Intelligence/openpi" - -[dependency-groups] -dev = [ - "pytest>=8.3.4", - "ruff>=0.8.6", - "pre-commit>=4.0.1", - "ipykernel>=6.29.5", - "ipywidgets>=8.1.5", - "matplotlib>=3.10.0", - "pynvml>=12.0.0", -] -rlds = [ - "dlimp", - "tensorflow-cpu==2.15.0", - "tensorflow-datasets==4.9.9", -] - -[tool.uv] -override-dependencies = ["datasets==3.6.0", "ml-dtypes==0.4.1", "tensorstore==0.1.74"] - -[tool.uv.sources] -openpi-client = { workspace = true } -lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" } -dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" } - -[tool.uv.workspace] -members = ["packages/*", "src/vggt"] - -[tool.ruff] -line-length = 120 -target-version = "py311" -extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"] - -[tool.ruff.lint] -# https://docs.astral.sh/ruff/rules/ -select = [ - "B", - "C4", - "DTZ", - "E4", - "E7", - "E9", - "F", - "FBT", - "FURB", - "I", - "ICN", - "ISC", - "LOG", - "N", - "PD", - "PERF", - "PIE", - "PLC", - "PLE", - "PLR1", - "PLR5", - "PLW", - "PT", - "Q", - "RET", - "RUF", - "SIM", - "SLF", - "T10", - "T20", - "UP", - "W", -] -ignore = [ - "F722", # Conflicts with array typing. - "T201", # We use print statements. - "PD008", # Lots of false positives. - "ISC001", # Disabling to support ruff format. - "LOG015", # Use logger.info. -] -unfixable = [ - "B905", # Fix defaults to strict=False, which is not what we want. 
-] - -[tool.ruff.lint.isort] -force-single-line = true -force-sort-within-sections = true -single-line-exclusions = ["collections.abc", "typing", "typing_extensions"] -known-third-party = ["wandb"] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.pytest.ini_options] -markers = ["manual: should be run manually."] -testpaths = ["src", "scripts", "packages"] diff --git a/capvector-pi05/scripts/__init__.py b/capvector-pi05/scripts/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/scripts/compute_norm_stats.py b/capvector-pi05/scripts/compute_norm_stats.py deleted file mode 100644 index 07ccb5e9ac4acb5c8c43ac64fbc2684edf8d7495..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/compute_norm_stats.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Compute normalization statistics for a config. - -This script is used to compute the normalization statistics for a given config. It -will compute the mean and standard deviation of the data in the dataset and save it -to the config assets directory. -""" - -import numpy as np -import tqdm -import tyro - -import openpi.models.model as _model -import openpi.shared.normalize as normalize -import openpi.training.config as _config -import openpi.training.data_loader as _data_loader -import openpi.transforms as transforms - - -class RemoveStrings(transforms.DataTransformFn): - def __call__(self, x: dict) -> dict: - return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)} - - -def create_torch_dataloader( - data_config: _config.DataConfig, - action_horizon: int, - batch_size: int, - model_config: _model.BaseModelConfig, - num_workers: int, - max_frames: int | None = None, -) -> tuple[_data_loader.Dataset, int]: - if data_config.repo_id is None: - raise ValueError("Data config must have a repo_id") - dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config) - dataset = _data_loader.TransformedDataset( - dataset, - [ - *data_config.repack_transforms.inputs, - *data_config.data_transforms.inputs, - # Remove strings since they are not supported by JAX and are not needed to compute norm stats. - RemoveStrings(), - ], - ) - if max_frames is not None and max_frames < len(dataset): - num_batches = max_frames // batch_size - shuffle = True - else: - num_batches = len(dataset) // batch_size - shuffle = False - data_loader = _data_loader.TorchDataLoader( - dataset, - local_batch_size=batch_size, - num_workers=num_workers, - shuffle=shuffle, - num_batches=num_batches, - ) - return data_loader, num_batches - - -def create_rlds_dataloader( - data_config: _config.DataConfig, - action_horizon: int, - batch_size: int, - max_frames: int | None = None, -) -> tuple[_data_loader.Dataset, int]: - dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False) - dataset = _data_loader.IterableTransformedDataset( - dataset, - [ - *data_config.repack_transforms.inputs, - *data_config.data_transforms.inputs, - # Remove strings since they are not supported by JAX and are not needed to compute norm stats. - RemoveStrings(), - ], - is_batched=True, - ) - if max_frames is not None and max_frames < len(dataset): - num_batches = max_frames // batch_size - else: - # NOTE: this length is currently hard-coded for DROID. 
- num_batches = len(dataset) // batch_size - data_loader = _data_loader.RLDSDataLoader( - dataset, - num_batches=num_batches, - ) - return data_loader, num_batches - - -def main(config_name: str, max_frames: int | None = None): - config = _config.get_config(config_name) - data_config = config.data.create(config.assets_dirs, config.model) - - if data_config.rlds_data_dir is not None: - data_loader, num_batches = create_rlds_dataloader( - data_config, config.model.action_horizon, config.batch_size, max_frames - ) - else: - data_loader, num_batches = create_torch_dataloader( - data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames - ) - - keys = ["state", "actions"] - stats = {key: normalize.RunningStats() for key in keys} - - for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"): - for key in keys: - stats[key].update(np.asarray(batch[key])) - - norm_stats = {key: stats.get_statistics() for key, stats in stats.items()} - - output_path = config.assets_dirs / data_config.repo_id - print(f"Writing stats to: {output_path}") - normalize.save(output_path, norm_stats) - - -if __name__ == "__main__": - tyro.cli(main) diff --git a/capvector-pi05/scripts/docker/compose.yml b/capvector-pi05/scripts/docker/compose.yml deleted file mode 100644 index 3655b85cf287df0f4e4e586bae97a6c607841516..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/docker/compose.yml +++ /dev/null @@ -1,29 +0,0 @@ -# Run with: -# docker compose -f scripts/docker/compose.yml up --build -services: - openpi_server: - image: openpi_server - build: - context: ../.. - dockerfile: scripts/docker/serve_policy.Dockerfile - init: true - tty: true - network_mode: host - # Populate configured openpi data home to /openpi_assets inside the container. - # Populate aws credential inside the container. - volumes: - - $PWD:/app - - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets - environment: - - SERVER_ARGS - - OPENPI_DATA_HOME=/openpi_assets - - IS_DOCKER=true - - # Comment out this block if not running on a machine with GPUs. - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] diff --git a/capvector-pi05/scripts/docker/install_docker_ubuntu22.sh b/capvector-pi05/scripts/docker/install_docker_ubuntu22.sh deleted file mode 100644 index cdda7fd608abde9aa99ab9c47049db6ae59a90db..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/docker/install_docker_ubuntu22.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash - -# Add Docker's official GPG key: -sudo apt-get update -sudo apt-get install -y ca-certificates curl -sudo install -m 0755 -d /etc/apt/keyrings -sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc -sudo chmod a+r /etc/apt/keyrings/docker.asc - -# Add the repository to Apt sources: -echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \ - $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | - sudo tee /etc/apt/sources.list.d/docker.list >/dev/null -sudo apt-get update - -sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin - -# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc). 
-# See https://docs.docker.com/engine/install/linux-postinstall/ -username=$(whoami) -sudo usermod -aG docker $username - -# Configure docker to start automatically on system boot. -sudo systemctl enable docker.service -sudo systemctl enable containerd.service - -# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5 -if [ -f ~/.docker/config.json ]; then - sed -i 's/credsStore/credStore/g' ~/.docker/config.json -fi - -echo "" -echo "********************************************************************" -echo "**** Restart to allow Docker permission changes to take effect. ****" -echo "********************************************************************" -echo "" diff --git a/capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh b/capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh deleted file mode 100644 index 1a1583309d936ad358551f7224bbce0d3bf5c9d1..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs. -# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html - -curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg && - curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | - sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | - sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list - -# NVIDIA's documentation omits 'sudo' in the following command, but it is required. -sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list -sudo apt-get update -sudo apt-get install -y nvidia-container-toolkit - -sudo nvidia-ctk runtime configure --runtime=docker -sudo systemctl restart docker diff --git a/capvector-pi05/scripts/docker/serve_policy.Dockerfile b/capvector-pi05/scripts/docker/serve_policy.Dockerfile deleted file mode 100644 index 4060254f052a48a7346b700394537a422dfb88ee..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/docker/serve_policy.Dockerfile +++ /dev/null @@ -1,38 +0,0 @@ -# Dockerfile for serving a PI policy. -# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container - -# Build the container: -# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile - -# Run the container: -# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash - -FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 -COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ - -WORKDIR /app - -# Needed because LeRobot uses git-lfs. -RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang - -# Copy from the cache instead of linking since it's a mounted volume -ENV UV_LINK_MODE=copy - -# Write the virtual environment outside of the project directory so it doesn't -# leak out of the container when we mount the application code.
-ENV UV_PROJECT_ENVIRONMENT=/.venv - -# Install the project's dependencies using the lockfile and settings -RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=uv.lock,target=uv.lock \ - --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \ - --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \ - GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev - -# Copy transformers_replace files while preserving directory structure -COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/ -RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace - -CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS" diff --git a/capvector-pi05/scripts/serve_policy.py b/capvector-pi05/scripts/serve_policy.py deleted file mode 100644 index edabae3deb3d7cc1c79a7fbb9e5a6059a8d82c01..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/serve_policy.py +++ /dev/null @@ -1,122 +0,0 @@ -import dataclasses -import enum -import logging -import socket - -import tyro - -from openpi.policies import policy as _policy -from openpi.policies import policy_config as _policy_config -from openpi.serving import websocket_policy_server -from openpi.training import config as _config - - -class EnvMode(enum.Enum): - """Supported environments.""" - - ALOHA = "aloha" - ALOHA_SIM = "aloha_sim" - DROID = "droid" - LIBERO = "libero" - - -@dataclasses.dataclass -class Checkpoint: - """Load a policy from a trained checkpoint.""" - - # Training config name (e.g., "pi0_aloha_sim"). - config: str - # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000"). - dir: str - - -@dataclasses.dataclass -class Default: - """Use the default policy for the given environment.""" - - -@dataclasses.dataclass -class Args: - """Arguments for the serve_policy script.""" - - # Environment to serve the policy for. This is only used when serving default policies. - env: EnvMode = EnvMode.ALOHA_SIM - - # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default - # prompt. - default_prompt: str | None = None - - # Port to serve the policy on. - port: int = 8000 - # Record the policy's behavior for debugging. - record: bool = False - - # Specifies how to load the policy. If not provided, the default policy for the environment will be used. - policy: Checkpoint | Default = dataclasses.field(default_factory=Default) - - -# Default checkpoints that should be used for each environment. 
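-# Used by create_default_policy() below when the user does not explicitly specify a checkpoint to load.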
-DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = { - EnvMode.ALOHA: Checkpoint( - config="pi05_aloha", - dir="gs://openpi-assets/checkpoints/pi05_base", - ), - EnvMode.ALOHA_SIM: Checkpoint( - config="pi0_aloha_sim", - dir="gs://openpi-assets/checkpoints/pi0_aloha_sim", - ), - EnvMode.DROID: Checkpoint( - config="pi05_droid", - dir="gs://openpi-assets/checkpoints/pi05_droid", - ), - EnvMode.LIBERO: Checkpoint( - config="pi05_libero", - dir="gs://openpi-assets/checkpoints/pi05_libero", - ), -} - - -def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy: - """Create a default policy for the given environment.""" - if checkpoint := DEFAULT_CHECKPOINT.get(env): - return _policy_config.create_trained_policy( - _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt - ) - raise ValueError(f"Unsupported environment mode: {env}") - - -def create_policy(args: Args) -> _policy.Policy: - """Create a policy from the given arguments.""" - match args.policy: - case Checkpoint(): - return _policy_config.create_trained_policy( - _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt - ) - case Default(): - return create_default_policy(args.env, default_prompt=args.default_prompt) - - -def main(args: Args) -> None: - policy = create_policy(args) - policy_metadata = policy.metadata - - # Record the policy's behavior. - if args.record: - policy = _policy.PolicyRecorder(policy, "policy_records") - - hostname = socket.gethostname() - local_ip = socket.gethostbyname(hostname) - logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip) - - server = websocket_policy_server.WebsocketPolicyServer( - policy=policy, - host="0.0.0.0", - port=args.port, - metadata=policy_metadata, - ) - server.serve_forever() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, force=True) - main(tyro.cli(Args)) diff --git a/capvector-pi05/scripts/train.py b/capvector-pi05/scripts/train.py deleted file mode 100644 index 3d37bb20f2f8dfadf94e77e30fe22a4b747fd137..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/train.py +++ /dev/null @@ -1,280 +0,0 @@ -import dataclasses -import functools -import logging -import platform -from typing import Any - -import etils.epath as epath -import flax.nnx as nnx -from flax.training import common_utils -import flax.traverse_util as traverse_util -import jax -import jax.experimental -import jax.numpy as jnp -import numpy as np -import optax -import tqdm_loggable.auto as tqdm -import wandb - -import openpi.models.model as _model -import openpi.shared.array_typing as at -import openpi.shared.nnx_utils as nnx_utils -import openpi.training.checkpoints as _checkpoints -import openpi.training.config as _config -import openpi.training.data_loader as _data_loader -import openpi.training.optimizer as _optimizer -import openpi.training.sharding as sharding -import openpi.training.utils as training_utils -import openpi.training.weight_loaders as _weight_loaders - - -def init_logging(): - """Custom logging format for better readability.""" - level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} - - class CustomFormatter(logging.Formatter): - def format(self, record): - record.levelname = level_mapping.get(record.levelname, record.levelname) - return super().format(record) - - formatter = CustomFormatter( - fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", - 
datefmt="%H:%M:%S", - ) - - logger = logging.getLogger() - logger.setLevel(logging.INFO) - logger.handlers[0].setFormatter(formatter) - - -def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True): - if not enabled: - wandb.init(mode="disabled") - return - - ckpt_dir = config.checkpoint_dir - if not ckpt_dir.exists(): - raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") - if resuming: - run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() - wandb.init(id=run_id, resume="must", project=config.project_name) - else: - wandb.init( - name=config.exp_name, - config=dataclasses.asdict(config), - project=config.project_name, - ) - (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) - - if log_code: - wandb.run.log_code(epath.Path(__file__).parent.parent) - - -def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params: - """Loads and validates the weights. Returns a loaded subset of the weights.""" - loaded_params = loader.load(params_shape) - at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True) - - # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned. - return traverse_util.unflatten_dict( - {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)} - ) - - -@at.typecheck -def init_train_state( - config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool -) -> tuple[training_utils.TrainState, Any]: - tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None) - - def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState: - rng, model_rng = jax.random.split(rng) - # initialize the model (and its parameters). - model = config.model.create(model_rng) - - # Merge the partial params into the model. - if partial_params is not None: - graphdef, state = nnx.split(model) - # This will produce an error if the partial params are not a subset of the state. - state.replace_by_pure_dict(partial_params) - model = nnx.merge(graphdef, state) - - params = nnx.state(model) - # Convert frozen params to bfloat16. - params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16))) - - return training_utils.TrainState( - step=0, - params=params, - model_def=nnx.graphdef(model), - tx=tx, - opt_state=tx.init(params.filter(config.trainable_filter)), - ema_decay=config.ema_decay, - ema_params=None if config.ema_decay is None else params, - ) - - train_state_shape = jax.eval_shape(init, init_rng) - state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True) - - if resume: - return train_state_shape, state_sharding - - partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict()) - replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - - # Initialize the train state and mix in the partial params. - train_state = jax.jit( - init, - donate_argnums=(1,), # donate the partial params buffer. 
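-        # The init inputs are replicated across devices, while the returned train state comes out sharded according to the FSDP spec computed above.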
- in_shardings=replicated_sharding, - out_shardings=state_sharding, - )(init_rng, partial_params) - - return train_state, state_sharding - - -@at.typecheck -def train_step( - config: _config.TrainConfig, - rng: at.KeyArrayLike, - state: training_utils.TrainState, - batch: tuple[_model.Observation, _model.Actions], -) -> tuple[training_utils.TrainState, dict[str, at.Array]]: - model = nnx.merge(state.model_def, state.params) - model.train() - - @at.typecheck - def loss_fn( - model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions - ): - chunked_loss = model.compute_loss(rng, observation, actions, train=True) - return jnp.mean(chunked_loss) - - train_rng = jax.random.fold_in(rng, state.step) - observation, actions = batch - - # Filter out frozen params. - diff_state = nnx.DiffState(0, config.trainable_filter) - loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions) - - params = state.params.filter(config.trainable_filter) - updates, new_opt_state = state.tx.update(grads, state.opt_state, params) - new_params = optax.apply_updates(params, updates) - - # Update the model in place and return the new full state. - nnx.update(model, new_params) - new_params = nnx.state(model) - - new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state) - if state.ema_decay is not None: - new_state = dataclasses.replace( - new_state, - ema_params=jax.tree.map( - lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params - ), - ) - - # Filter out params that aren't kernels. - kernel_params = nnx.state( - model, - nnx.All( - nnx.Param, - nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")), - lambda _, x: x.value.ndim > 1, - ), - ) - info = { - "loss": loss, - "grad_norm": optax.global_norm(grads), - "param_norm": optax.global_norm(kernel_params), - } - return new_state, info - - -def main(config: _config.TrainConfig): - init_logging() - logging.info(f"Running on: {platform.node()}") - - if config.batch_size % jax.device_count() != 0: - raise ValueError( - f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}." - ) - - jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser())) - - rng = jax.random.key(config.seed) - train_rng, init_rng = jax.random.split(rng) - - mesh = sharding.make_mesh(config.fsdp_devices) - data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS)) - replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - - checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir( - config.checkpoint_dir, - keep_period=config.keep_period, - overwrite=config.overwrite, - resume=config.resume, - ) - init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) - - data_loader = _data_loader.create_data_loader( - config, - sharding=data_sharding, - shuffle=True, - ) - data_iter = iter(data_loader) - batch = next(data_iter) - logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}") - - # Log images from first batch to sanity check. 
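-    # Concatenate the camera views side by side and log up to five examples from the first batch.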
- images_to_log = [ - wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1)) - for i in range(min(5, len(next(iter(batch[0].images.values()))))) - ] - wandb.log({"camera_views": images_to_log}, step=0) - - train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming) - jax.block_until_ready(train_state) - logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}") - - if resuming: - train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader) - - ptrain_step = jax.jit( - functools.partial(train_step, config), - in_shardings=(replicated_sharding, train_state_sharding, data_sharding), - out_shardings=(train_state_sharding, replicated_sharding), - donate_argnums=(1,), - ) - - start_step = int(train_state.step) - pbar = tqdm.tqdm( - range(start_step, config.num_train_steps), - initial=start_step, - total=config.num_train_steps, - dynamic_ncols=True, - ) - - infos = [] - for step in pbar: - with sharding.set_mesh(mesh): - train_state, info = ptrain_step(train_rng, train_state, batch) - infos.append(info) - if step % config.log_interval == 0: - stacked_infos = common_utils.stack_forest(infos) - reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos)) - info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items()) - pbar.write(f"Step {step}: {info_str}") - wandb.log(reduced_info, step=step) - infos = [] - batch = next(data_iter) - - if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1: - _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step) - - logging.info("Waiting for checkpoint manager to finish") - checkpoint_manager.wait_until_finished() - - -if __name__ == "__main__": - main(_config.cli()) diff --git a/capvector-pi05/scripts/train_align_pytorch.py b/capvector-pi05/scripts/train_align_pytorch.py deleted file mode 100644 index ff1479ce7927f49e0b46e382cf7c57ef0ea86e72..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/train_align_pytorch.py +++ /dev/null @@ -1,658 +0,0 @@ -""" -PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. -This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs -entirely in PyTorch using the `PI0Pytorch` model and your existing config/data -pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. 
- -Usage -Single GPU: - python scripts/train_pytorch.py --exp_name --save_interval - Example: - python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test - python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint -Multi-GPU (single node): - torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name - Example: - torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test - torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume -Multi-Node Training: - torchrun \ - --nnodes= --nproc_per_node= --node_rank= \ - --master_addr= --master_port= \ - scripts/train_pytorch.py --exp_name= --save_interval - -""" - -import dataclasses -import gc -import logging -import os -import platform -import shutil -import time - -import jax -import numpy as np -import safetensors.torch -import torch -import torch.distributed as dist -import torch.nn.parallel -import tqdm -import wandb - -import openpi.models.pi0_config -from openpi.models_pytorch import pi0_pytorch, pi0_align_pytorch, projectors -import openpi.shared.normalize as _normalize -import openpi.training.config as _config -import openpi.training.data_loader as _data - -from vggt.models.vggt import VGGT - - -def init_logging(): - level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} - - class CustomFormatter(logging.Formatter): - def format(self, record): - record.levelname = level_mapping.get(record.levelname, record.levelname) - return super().format(record) - - formatter = CustomFormatter( - fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", - datefmt="%H:%M:%S", - ) - logger = logging.getLogger() - logger.setLevel(logging.INFO) - if not logger.handlers: - ch = logging.StreamHandler() - ch.setFormatter(formatter) - logger.addHandler(ch) - else: - logger.handlers[0].setFormatter(formatter) - - -def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): - """Initialize wandb logging.""" - if not enabled: - wandb.init(mode="disabled") - return - - ckpt_dir = config.checkpoint_dir - if not ckpt_dir.exists(): - raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") - - if resuming: - run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() - wandb.init(id=run_id, resume="must", project=config.project_name) - else: - wandb.init( - name=config.exp_name, - config=dataclasses.asdict(config), - project=config.project_name, - ) - (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) - - -def setup_ddp(): - world_size = int(os.environ.get("WORLD_SIZE", "1")) - use_ddp = world_size > 1 - if use_ddp and not torch.distributed.is_initialized(): - backend = "nccl" if torch.cuda.is_available() else "gloo" - torch.distributed.init_process_group(backend=backend, init_method="env://") - - # Set up debugging environment variables for DDP issues - if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: - os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" - - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) - device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") - if torch.cuda.is_available(): - torch.cuda.set_device(device) - return use_ddp, local_rank, device - - -def cleanup_ddp(): - if torch.distributed.is_initialized(): - torch.distributed.barrier() - 
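-        # All ranks have synchronized at the barrier above, so it is now safe to tear down the process group.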
torch.distributed.destroy_process_group() - - -def set_seed(seed: int, local_rank: int): - torch.manual_seed(seed + local_rank) - np.random.seed(seed + local_rank) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed + local_rank) - - -def build_datasets(config: _config.TrainConfig): - # Use the unified data loader with PyTorch framework - data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) - return data_loader, data_loader.data_config() - - -def get_model_state_dict(model): - """Get state dict from model, handling DDP wrapper.""" - return ( - model.module.state_dict() - if isinstance(model, torch.nn.parallel.DistributedDataParallel) - else model.state_dict() - ) - - -def get_model_parameters(model): - """Get parameters from model, handling DDP wrapper.""" - return ( - model.module.parameters() - if isinstance(model, torch.nn.parallel.DistributedDataParallel) - else model.parameters() - ) - - -def save_checkpoint(model, optimizer, global_step, config, is_main, data_config): - """Save a checkpoint with model state, optimizer state, and metadata.""" - if not is_main: - return - - # Only save if it's time to save or if it's the final step - if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: - # Create temporary directory for atomic checkpoint saving - final_ckpt_dir = config.checkpoint_dir / f"{global_step}" - tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}" - - # Remove any existing temp directory and create new one - if tmp_ckpt_dir.exists(): - shutil.rmtree(tmp_ckpt_dir) - tmp_ckpt_dir.mkdir(parents=True, exist_ok=True) - - # Save model state using safetensors (handle shared tensors) - model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors") - - # Save optimizer state using PyTorch format - torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt") - - # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) - metadata = { - "global_step": global_step, - "config": dataclasses.asdict(config), - "timestamp": time.time(), - } - torch.save(metadata, tmp_ckpt_dir / "metadata.pt") - - # save norm stats - norm_stats = data_config.norm_stats - if norm_stats is not None and data_config.asset_id is not None: - _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats) - - # Atomically move temp directory to final location - if final_ckpt_dir.exists(): - shutil.rmtree(final_ckpt_dir) - tmp_ckpt_dir.rename(final_ckpt_dir) - - logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}") - - # Log checkpoint to wandb - if config.wandb_enabled: - wandb.log({"checkpoint_step": global_step}, step=global_step) - - -def load_checkpoint(model, optimizer, checkpoint_dir, device): - """Load the latest checkpoint and return the global step.""" - checkpoint_steps = [ - int(d.name) - for d in checkpoint_dir.iterdir() - if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") - ] - - if not checkpoint_steps: - raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - - latest_step = max(checkpoint_steps) - ckpt_dir = checkpoint_dir / f"{latest_step}" - - # Clear memory before loading checkpoints - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "before_loading_checkpoint") - - try: - # Load model state with error 
handling - logging.info("Loading model state...") - safetensors_path = ckpt_dir / "model.safetensors" - - if safetensors_path.exists(): - model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device)) - logging.info("Loaded model state from safetensors format") - else: - raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}") - - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_model") - - # Load optimizer state with error handling - logging.info("Loading optimizer state...") - optimizer_path = ckpt_dir / "optimizer.pt" - - if optimizer_path.exists(): - optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False) - logging.info("Loaded optimizer state from pt format") - else: - raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}") - - optimizer.load_state_dict(optimizer_state_dict) - del optimizer_state_dict - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_optimizer") - - # Load metadata - logging.info("Loading metadata...") - metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) - global_step = metadata.get("global_step", latest_step) - del metadata - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_metadata") - - logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") - return global_step - - except RuntimeError as e: - if "out of memory" in str(e): - # Clear memory and provide detailed error message - torch.cuda.empty_cache() - gc.collect() - logging.error(f"Out of memory error while loading checkpoint: {e!s}") - log_memory_usage(device, latest_step, "after_oom_error") - raise RuntimeError( - "Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" - ) from e - raise - - -def get_latest_checkpoint_step(checkpoint_dir): - """Get the latest checkpoint step number from a checkpoint directory.""" - checkpoint_steps = [ - int(d.name) - for d in checkpoint_dir.iterdir() - if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") - ] - return max(checkpoint_steps) if checkpoint_steps else None - - -def log_memory_usage(device, step, phase="unknown"): - """Log detailed memory usage information.""" - if not torch.cuda.is_available(): - return - - memory_allocated = torch.cuda.memory_allocated(device) / 1e9 - memory_reserved = torch.cuda.memory_reserved(device) / 1e9 - memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) - memory_free = memory_free / 1e9 - - # Get more detailed memory info - memory_stats = torch.cuda.memory_stats(device) - max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9 - max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9 - - # Get DDP info if available - ddp_info = "" - if dist.is_initialized(): - ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" - - logging.info( - f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}" - ) - - -def train_loop(config: _config.TrainConfig): - use_ddp, local_rank, device = setup_ddp() - is_main = (not use_ddp) or (dist.get_rank() == 0) - set_seed(config.seed, local_rank) - - # Initialize checkpoint directory and wandb - resuming = False - if config.resume: - # Find checkpoint directory based on experiment name - exp_checkpoint_dir = config.checkpoint_dir - if exp_checkpoint_dir.exists(): - # Use validation to find the latest working checkpoint - latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) - if latest_step is not None: - resuming = True - logging.info( - f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" - ) - else: - raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") - else: - raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") - elif config.overwrite and config.checkpoint_dir.exists(): - shutil.rmtree(config.checkpoint_dir) - logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") - - # Create checkpoint directory with experiment name - if not resuming: - # For new runs, create experiment-specific checkpoint directory - exp_checkpoint_dir = config.checkpoint_dir - exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") - else: - # For resume, checkpoint_dir is already set to the experiment directory - logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") - - # Initialize wandb (only on main process) - if is_main: - init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) - - # Build data loader using the unified data loader - # Calculate effective batch size per GPU for DDP - # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size - world_size = torch.distributed.get_world_size() if use_ddp else 1 - effective_batch_size = config.batch_size // world_size - logging.info( - f"Using batch size per GPU: 
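`log_memory_usage` above converts allocator counters to GB with a 1e9 divisor; note that the value it reports as `free` is reserved-but-unallocated memory inside PyTorch's caching allocator, not free device memory. A small sketch of the same bookkeeping packaged as a dict-returning helper (the name `cuda_memory_summary` is illustrative):

```python
import torch


def cuda_memory_summary(device: torch.device) -> dict[str, float]:
    """Summarize CUDA allocator state in GB (1e9 bytes), as in the logging above."""
    if not torch.cuda.is_available():
        return {}
    stats = torch.cuda.memory_stats(device)
    gb = 1e9
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    return {
        "allocated": allocated / gb,
        "reserved": reserved / gb,
        "reserved_unallocated": (reserved - allocated) / gb,  # reported as "free" in the log line above
        "peak_allocated": stats.get("allocated_bytes.all.peak", 0) / gb,
        "peak_reserved": stats.get("reserved_bytes.all.peak", 0) / gb,
    }
```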
{effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" - ) - - # Pass the original batch size to data loader - it will handle DDP splitting internally - loader, data_config = build_datasets(config) - - # Log sample images to wandb on first batch - if is_main and config.wandb_enabled and not resuming: - # Create a separate data loader for sample batch to avoid consuming the main loader - sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) - sample_batch = next(iter(sample_data_loader)) - # Convert observation and actions to torch tensors - observation, actions = sample_batch - sample_batch = observation.to_dict() - sample_batch["actions"] = actions - - # Create sample images for wandb - images_to_log = [] - # Get batch size from the first image tensor - batch_size = next(iter(sample_batch["image"].values())).shape[0] - for i in range(min(5, batch_size)): - # Concatenate all camera views horizontally for this batch item - # Convert from NCHW to NHWC format for wandb - img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) - img_concatenated = img_concatenated.cpu().numpy() - images_to_log.append(wandb.Image(img_concatenated)) - - wandb.log({"camera_views": images_to_log}, step=0) - - # Clear sample batch from memory aggressively - del sample_batch, observation, actions, images_to_log, img_concatenated - del sample_data_loader # Also delete the sample data loader - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - logging.info("Cleared sample batch and data loader from memory") - - # Build model - if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): - # Convert dataclass to Pi0Config if needed - model_cfg = openpi.models.pi0_config.Pi0Config( - dtype=config.pytorch_training_precision, - action_dim=config.model.action_dim, - action_horizon=config.model.action_horizon, - max_token_len=config.model.max_token_len, - paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), - action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), - pi05=getattr(config.model, "pi05", False), - ) - else: - model_cfg = config.model - # Update dtype to match pytorch_training_precision - object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) - - model = openpi.models_pytorch.pi0_align_pytorch.PI0Pytorch(model_cfg, config).to(device) - vggt_model = VGGT( - enable_camera=False, - enable_point=False, - enable_depth=False, - enable_track=False, - feature_only=True, - ).to(device) - align_projector = projectors.AlignProjector( - model.LLM_width, - config.vggt_dim, - config.use_vlm_norm).to(device) - - if hasattr(model, "gradient_checkpointing_enable"): - enable_gradient_checkpointing = True - model.gradient_checkpointing_enable() - logging.info("Enabled gradient checkpointing for memory optimization") - else: - enable_gradient_checkpointing = False - logging.info("Gradient checkpointing is not supported for this model") - - # Log initial memory usage after model creation - if is_main and torch.cuda.is_available(): - log_memory_usage(device, 0, "after_model_creation") - - # Enable memory optimizations for large-scale training - if world_size >= 8: - torch.backends.cudnn.benchmark = True - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - # Set memory allocation configuration - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" 
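Earlier in this block, the sample camera views are permuted from NCHW to HWC and concatenated horizontally before being logged with `wandb.Image`. A standalone sketch of that transformation (the `camera_grid` helper is hypothetical; `images` maps camera names to NCHW float tensors, as in the sample batch above):

```python
import torch
import wandb


def camera_grid(images: dict[str, torch.Tensor], index: int) -> wandb.Image:
    """Stitch every camera view of one batch item side by side for logging."""
    views = [img[index].permute(1, 2, 0) for img in images.values()]  # NCHW -> HWC per view
    panel = torch.cat(views, dim=1)  # concatenate along the width dimension
    return wandb.Image(panel.cpu().numpy())
```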
- logging.info("Enabled memory optimizations for 8+ GPU training") - - if use_ddp: - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[device.index] if device.type == "cuda" else None, - find_unused_parameters=True, # Disable for memory efficiency - gradient_as_bucket_view=True, # Enable for memory efficiency - static_graph=world_size >= 8, # Enable for 8+ GPUs - ) - align_projector = torch.nn.parallel.DistributedDataParallel( - align_projector, - device_ids=[device.index] if device.type == "cuda" else None, - find_unused_parameters=True, # Disable for memory efficiency - gradient_as_bucket_view=True, # Enable for memory efficiency - static_graph=world_size >= 8, # Enable for 8+ GPUs - ) - - # Load weights from weight_loader if specified (for fine-tuning) - if config.pytorch_weight_path is not None: - logging.info(f"Loading weights from: {config.pytorch_weight_path}") - model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") - safetensors.torch.load_model( - (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), - model_path, - strict=False, - ) - logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") - if config.vggt_weight_path is not None: - vggt_path = os.path.join(config.vggt_weight_path, "model.pt") - if not os.path.exists(vggt_path): - raise FileNotFoundError(f"VGGT weight file not found at {vggt_path}") - vggt_model.load_state_dict(torch.load(vggt_path), strict=False) - logging.info(f"Loaded VGGT weights from {config.vggt_weight_path}") - - # Optimizer + learning rate schedule from config - warmup_steps = config.lr_schedule.warmup_steps - peak_lr = config.lr_schedule.peak_lr - decay_steps = config.lr_schedule.decay_steps - end_lr = config.lr_schedule.decay_lr - - # Create optimizer with config parameters - optim = torch.optim.AdamW( - list(model.parameters()) + list(align_projector.parameters()), - lr=peak_lr, - betas=(config.optimizer.b1, config.optimizer.b2), - eps=config.optimizer.eps, - weight_decay=config.optimizer.weight_decay, - ) - - # Load checkpoint if resuming - global_step = 0 - if resuming: - global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) - logging.info(f"Resumed training from step {global_step}") - - def lr_schedule(step: int): - if step < warmup_steps: - # Match JAX behavior: start from peak_lr / (warmup_steps + 1) - init_lr = peak_lr / (warmup_steps + 1) - return init_lr + (peak_lr - init_lr) * step / warmup_steps - # cosine decay - progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) - cos = 0.5 * (1 + np.cos(np.pi * progress)) - return end_lr + (peak_lr - end_lr) * cos - - model.train() - align_projector.train() - vggt_model.eval() - start_time = time.time() - infos = [] # Collect stats over log interval - if is_main: - logging.info( - f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}" - ) - logging.info( - f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" - ) - logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") - logging.info( - f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" - ) - logging.info( - f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" - ) - logging.info("EMA is not 
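The `lr_schedule` defined above warms up linearly from `peak_lr / (warmup_steps + 1)` to `peak_lr`, then cosine-decays toward `end_lr`. A self-contained sketch with illustrative numbers (warmup=1000, peak=2.5e-5, decay over 30000 steps to 2.5e-6 are assumptions, not the repository defaults):

```python
import numpy as np


def lr_at(step, warmup_steps=1_000, peak_lr=2.5e-5, decay_steps=30_000, end_lr=2.5e-6):
    """Linear warmup followed by cosine decay, matching the schedule above."""
    if step < warmup_steps:
        init_lr = peak_lr / (warmup_steps + 1)
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    return end_lr + (peak_lr - end_lr) * 0.5 * (1 + np.cos(np.pi * progress))


for s in (0, 500, 1_000, 15_000, 30_000):
    print(s, f"{lr_at(s):.2e}")  # ~2.5e-08, ~1.3e-05, 2.5e-05, ~1.4e-05, 2.5e-06
```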
supported for PyTorch training") - logging.info(f"Training precision: {model_cfg.dtype}") - - # Training loop - iterate until we reach num_train_steps - pbar = ( - tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) - if is_main - else None - ) - - while global_step < config.num_train_steps: - # Set epoch for distributed training - if use_ddp and hasattr(loader, "set_epoch"): - loader.set_epoch(global_step // len(loader)) - - for observation, actions in loader: - # Check if we've reached the target number of steps - if global_step >= config.num_train_steps: - break - - # The unified data loader returns (observation, actions) tuple - observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901 - actions = actions.to(torch.float32) # noqa: PLW2901 - actions = actions.to(device) # noqa: PLW2901 - - # Update LR - for pg in optim.param_groups: - pg["lr"] = lr_schedule(global_step) - - # Forward pass - action_losses, align_loss = model(observation, actions, vggt=vggt_model, align_proj=align_projector) - loss = action_losses + config.align_loss_coeff * align_loss - - # Backward pass - loss.backward() - - # Log memory usage after backward pass - if global_step < 5 and is_main and torch.cuda.is_available(): - log_memory_usage(device, global_step, "after_backward") - - # Gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) - - # Optimizer step - optim.step() - optim.zero_grad(set_to_none=True) - - # Clear gradients more aggressively - for param in model.parameters(): - if param.grad is not None: - param.grad.detach_() - param.grad = None - - # Collect stats - if is_main: - infos.append( - { - "action_loss": action_losses.item(), - "align_loss": align_loss.item(), - "learning_rate": optim.param_groups[0]["lr"], - "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, - } - ) - - if is_main and (global_step % config.log_interval == 0): - elapsed = time.time() - start_time - - # Average stats over log interval - avg_loss = sum(info["action_loss"] for info in infos) / len(infos) - avg_align_loss = sum(info["align_loss"] for info in infos) / len(infos) - avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) - - avg_grad_norm = None - if any("grad_norm" in info for info in infos): - vals = [ - info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None - ] - if len(vals) > 0: - avg_grad_norm = sum(vals) / len(vals) - logging.info( - f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" - if avg_grad_norm is not None - else f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s" - ) - - # Log to wandb - if config.wandb_enabled and len(infos) > 0: - log_payload = { - "action_loss": avg_loss, - "align_loss": avg_align_loss, - "learning_rate": avg_lr, - "step": global_step, - "time_per_step": elapsed / config.log_interval, - } - if avg_grad_norm is not None: - log_payload["grad_norm"] = avg_grad_norm - wandb.log(log_payload, step=global_step) - - start_time = time.time() - infos = [] # Reset stats collection - - global_step += 1 - # Save checkpoint using the new mechanism - save_checkpoint(model, optim, global_step, config, is_main, data_config) - - # Update progress bar - if pbar is not None: - pbar.update(1) - pbar.set_postfix( - {"loss": 
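In the loop above, the model returns an action loss together with an auxiliary alignment loss, and the two are summed with the `config.align_loss_coeff` weight before `backward()`. A toy sketch of that weighting (all values, including the 0.1 coefficient, are invented):

```python
import torch

action_loss = torch.tensor(0.42, requires_grad=True)
align_loss = torch.tensor(1.30, requires_grad=True)
align_loss_coeff = 0.1  # stands in for config.align_loss_coeff

total = action_loss + align_loss_coeff * align_loss  # 0.42 + 0.1 * 1.30 = 0.55
total.backward()
print(f"total={total.item():.2f} d_total/d_align={align_loss.grad.item():.2f}")  # 0.55, 0.10
```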
f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step} - ) - - # Close progress bar - if pbar is not None: - pbar.close() - - # Finish wandb run - if is_main and config.wandb_enabled: - wandb.finish() - - cleanup_ddp() - - -def main(): - init_logging() - config = _config.cli() - train_loop(config) - - -if __name__ == "__main__": - main() diff --git a/capvector-pi05/scripts/train_pytorch.py b/capvector-pi05/scripts/train_pytorch.py deleted file mode 100644 index a03c206466e99c506a0debc0ec51b5b3302d0249..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/train_pytorch.py +++ /dev/null @@ -1,632 +0,0 @@ -""" -PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. -This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs -entirely in PyTorch using the `PI0Pytorch` model and your existing config/data -pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. - -Usage -Single GPU: - python scripts/train_pytorch.py --exp_name --save_interval - Example: - python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test - python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint -Multi-GPU (single node): - torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name - Example: - torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test - torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume -Multi-Node Training: - torchrun \ - --nnodes= --nproc_per_node= --node_rank= \ - --master_addr= --master_port= \ - scripts/train_pytorch.py --exp_name= --save_interval - -""" - -import dataclasses -import gc -import logging -import os -import platform -import shutil -import time - -import jax -import numpy as np -import safetensors.torch -import torch -import torch.distributed as dist -import torch.nn.parallel -import tqdm -import wandb - -import openpi.models.pi0_config -import openpi.models_pytorch.pi0_pytorch -import openpi.shared.normalize as _normalize -import openpi.training.config as _config -import openpi.training.data_loader as _data - - -def init_logging(): - level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} - - class CustomFormatter(logging.Formatter): - def format(self, record): - record.levelname = level_mapping.get(record.levelname, record.levelname) - return super().format(record) - - formatter = CustomFormatter( - fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", - datefmt="%H:%M:%S", - ) - logger = logging.getLogger() - logger.setLevel(logging.INFO) - if not logger.handlers: - ch = logging.StreamHandler() - ch.setFormatter(formatter) - logger.addHandler(ch) - else: - logger.handlers[0].setFormatter(formatter) - - -def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): - """Initialize wandb logging.""" - if not enabled: - wandb.init(mode="disabled") - return - - ckpt_dir = config.checkpoint_dir - if not ckpt_dir.exists(): - raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") - - if resuming: - run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() - wandb.init(id=run_id, resume="must", project=config.project_name) - else: - wandb.init( - name=config.exp_name, - config=dataclasses.asdict(config), - 
project=config.project_name, - ) - (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) - - -def setup_ddp(): - world_size = int(os.environ.get("WORLD_SIZE", "1")) - use_ddp = world_size > 1 - if use_ddp and not torch.distributed.is_initialized(): - backend = "nccl" if torch.cuda.is_available() else "gloo" - torch.distributed.init_process_group(backend=backend, init_method="env://") - - # Set up debugging environment variables for DDP issues - if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: - os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" - - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) - device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") - if torch.cuda.is_available(): - torch.cuda.set_device(device) - return use_ddp, local_rank, device - - -def cleanup_ddp(): - if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -def set_seed(seed: int, local_rank: int): - torch.manual_seed(seed + local_rank) - np.random.seed(seed + local_rank) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed + local_rank) - - -def build_datasets(config: _config.TrainConfig): - # Use the unified data loader with PyTorch framework - data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) - return data_loader, data_loader.data_config() - - -def get_model_state_dict(model): - """Get state dict from model, handling DDP wrapper.""" - return ( - model.module.state_dict() - if isinstance(model, torch.nn.parallel.DistributedDataParallel) - else model.state_dict() - ) - - -def get_model_parameters(model): - """Get parameters from model, handling DDP wrapper.""" - return ( - model.module.parameters() - if isinstance(model, torch.nn.parallel.DistributedDataParallel) - else model.parameters() - ) - - -def save_checkpoint(model, optimizer, global_step, config, is_main, data_config): - """Save a checkpoint with model state, optimizer state, and metadata.""" - if not is_main: - return - - # Only save if it's time to save or if it's the final step - if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: - # Create temporary directory for atomic checkpoint saving - final_ckpt_dir = config.checkpoint_dir / f"{global_step}" - tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}" - - # Remove any existing temp directory and create new one - if tmp_ckpt_dir.exists(): - shutil.rmtree(tmp_ckpt_dir) - tmp_ckpt_dir.mkdir(parents=True, exist_ok=True) - - # Save model state using safetensors (handle shared tensors) - model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors") - - # Save optimizer state using PyTorch format - torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt") - - # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) - metadata = { - "global_step": global_step, - "config": dataclasses.asdict(config), - "timestamp": time.time(), - } - torch.save(metadata, tmp_ckpt_dir / "metadata.pt") - - # save norm stats - norm_stats = data_config.norm_stats - if norm_stats is not None and data_config.asset_id is not None: - _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats) - - # Atomically move temp directory to final location - if final_ckpt_dir.exists(): - shutil.rmtree(final_ckpt_dir) - 
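The `save_checkpoint` function above stages every artifact in a `tmp_<step>` directory and only renames it into place once all writes succeed, so a reader never observes a half-written checkpoint. A condensed sketch of the same pattern (`save_state` is a hypothetical callback standing in for the safetensors/optimizer/metadata writes):

```python
import shutil
from pathlib import Path


def atomic_save(checkpoint_root: Path, step: int, save_state) -> Path:
    """Write a checkpoint so readers only ever see complete directories."""
    final_dir = checkpoint_root / f"{step}"
    tmp_dir = checkpoint_root / f"tmp_{step}"
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)  # discard leftovers from a crashed save
    tmp_dir.mkdir(parents=True, exist_ok=True)
    save_state(tmp_dir)  # e.g. model.safetensors, optimizer.pt, metadata.pt
    if final_dir.exists():
        shutil.rmtree(final_dir)  # allow re-saving the same step
    tmp_dir.rename(final_dir)  # rename is atomic within one filesystem
    return final_dir
```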
tmp_ckpt_dir.rename(final_ckpt_dir) - - logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}") - - # Log checkpoint to wandb - if config.wandb_enabled: - wandb.log({"checkpoint_step": global_step}, step=global_step) - - -def load_checkpoint(model, optimizer, checkpoint_dir, device): - """Load the latest checkpoint and return the global step.""" - checkpoint_steps = [ - int(d.name) - for d in checkpoint_dir.iterdir() - if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") - ] - - if not checkpoint_steps: - raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - - latest_step = max(checkpoint_steps) - ckpt_dir = checkpoint_dir / f"{latest_step}" - - # Clear memory before loading checkpoints - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "before_loading_checkpoint") - - try: - # Load model state with error handling - logging.info("Loading model state...") - safetensors_path = ckpt_dir / "model.safetensors" - - if safetensors_path.exists(): - model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device)) - logging.info("Loaded model state from safetensors format") - else: - raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}") - - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_model") - - # Load optimizer state with error handling - logging.info("Loading optimizer state...") - optimizer_path = ckpt_dir / "optimizer.pt" - - if optimizer_path.exists(): - optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False) - logging.info("Loaded optimizer state from pt format") - else: - raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}") - - optimizer.load_state_dict(optimizer_state_dict) - del optimizer_state_dict - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_optimizer") - - # Load metadata - logging.info("Loading metadata...") - metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) - global_step = metadata.get("global_step", latest_step) - del metadata - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_metadata") - - logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") - return global_step - - except RuntimeError as e: - if "out of memory" in str(e): - # Clear memory and provide detailed error message - torch.cuda.empty_cache() - gc.collect() - logging.error(f"Out of memory error while loading checkpoint: {e!s}") - log_memory_usage(device, latest_step, "after_oom_error") - raise RuntimeError( - "Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" - ) from e - raise - - -def get_latest_checkpoint_step(checkpoint_dir): - """Get the latest checkpoint step number from a checkpoint directory.""" - checkpoint_steps = [ - int(d.name) - for d in checkpoint_dir.iterdir() - if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") - ] - return max(checkpoint_steps) if checkpoint_steps else None - - -def log_memory_usage(device, step, phase="unknown"): - """Log detailed memory usage information.""" - if not torch.cuda.is_available(): - return - - memory_allocated = torch.cuda.memory_allocated(device) / 1e9 - memory_reserved = torch.cuda.memory_reserved(device) / 1e9 - memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) - memory_free = memory_free / 1e9 - - # Get more detailed memory info - memory_stats = torch.cuda.memory_stats(device) - max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9 - max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9 - - # Get DDP info if available - ddp_info = "" - if dist.is_initialized(): - ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" - - logging.info( - f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}" - ) - - -def train_loop(config: _config.TrainConfig): - use_ddp, local_rank, device = setup_ddp() - is_main = (not use_ddp) or (dist.get_rank() == 0) - set_seed(config.seed, local_rank) - - # Initialize checkpoint directory and wandb - resuming = False - if config.resume: - # Find checkpoint directory based on experiment name - exp_checkpoint_dir = config.checkpoint_dir - if exp_checkpoint_dir.exists(): - # Use validation to find the latest working checkpoint - latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) - if latest_step is not None: - resuming = True - logging.info( - f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" - ) - else: - raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") - else: - raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") - elif config.overwrite and config.checkpoint_dir.exists(): - shutil.rmtree(config.checkpoint_dir) - logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") - - # Create checkpoint directory with experiment name - if not resuming: - # For new runs, create experiment-specific checkpoint directory - exp_checkpoint_dir = config.checkpoint_dir - exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") - else: - # For resume, checkpoint_dir is already set to the experiment directory - logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") - - # Initialize wandb (only on main process) - if is_main: - init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) - - # Build data loader using the unified data loader - # Calculate effective batch size per GPU for DDP - # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size - world_size = torch.distributed.get_world_size() if use_ddp else 1 - effective_batch_size = config.batch_size // world_size - logging.info( - f"Using batch size per GPU: 
{effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" - ) - - # Pass the original batch size to data loader - it will handle DDP splitting internally - loader, data_config = build_datasets(config) - - # Log sample images to wandb on first batch - if is_main and config.wandb_enabled and not resuming: - # Create a separate data loader for sample batch to avoid consuming the main loader - sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) - sample_batch = next(iter(sample_data_loader)) - # Convert observation and actions to torch tensors - observation, actions = sample_batch - sample_batch = observation.to_dict() - sample_batch["actions"] = actions - - # Create sample images for wandb - images_to_log = [] - # Get batch size from the first image tensor - batch_size = next(iter(sample_batch["image"].values())).shape[0] - for i in range(min(5, batch_size)): - # Concatenate all camera views horizontally for this batch item - # Convert from NCHW to NHWC format for wandb - img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) - img_concatenated = img_concatenated.cpu().numpy() - images_to_log.append(wandb.Image(img_concatenated)) - - wandb.log({"camera_views": images_to_log}, step=0) - - # Clear sample batch from memory aggressively - del sample_batch, observation, actions, images_to_log, img_concatenated - del sample_data_loader # Also delete the sample data loader - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - logging.info("Cleared sample batch and data loader from memory") - - # Build model - if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): - # Convert dataclass to Pi0Config if needed - model_cfg = openpi.models.pi0_config.Pi0Config( - dtype=config.pytorch_training_precision, - action_dim=config.model.action_dim, - action_horizon=config.model.action_horizon, - max_token_len=config.model.max_token_len, - paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), - action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), - pi05=getattr(config.model, "pi05", False), - ) - else: - model_cfg = config.model - # Update dtype to match pytorch_training_precision - object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) - - model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) - - if hasattr(model, "gradient_checkpointing_enable"): - enable_gradient_checkpointing = True - model.gradient_checkpointing_enable() - logging.info("Enabled gradient checkpointing for memory optimization") - else: - enable_gradient_checkpointing = False - logging.info("Gradient checkpointing is not supported for this model") - - # Log initial memory usage after model creation - if is_main and torch.cuda.is_available(): - log_memory_usage(device, 0, "after_model_creation") - - # Enable memory optimizations for large-scale training - if world_size >= 8: - torch.backends.cudnn.benchmark = True - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - # Set memory allocation configuration - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" - logging.info("Enabled memory optimizations for 8+ GPU training") - - if use_ddp: - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[device.index] if device.type == "cuda" else None, - find_unused_parameters=True, # Disable for memory efficiency - 
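The DDP wrapping above passes `find_unused_parameters=True` (DDP re-scans the autograd graph each step and tolerates parameters that receive no gradient), `gradient_as_bucket_view=True` (gradients alias the communication buckets, avoiding an extra gradient copy), and enables `static_graph` only for 8 or more GPUs. A minimal single-process sketch of the same constructor flags (the gloo group, port, and toy Linear model are purely illustrative):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# One-process group just for illustration; torchrun supplies the real rank/world size.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

ddp_model = DDP(
    torch.nn.Linear(8, 8),
    find_unused_parameters=True,   # tolerate parameters without gradients, at extra cost per step
    gradient_as_bucket_view=True,  # grads share storage with reducer buckets to reduce memory
    static_graph=False,            # set True only if the used-parameter set never changes
)
print(type(ddp_model.module).__name__)  # Linear
dist.destroy_process_group()
```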
gradient_as_bucket_view=True, # Enable for memory efficiency - static_graph=world_size >= 8, # Enable for 8+ GPUs - ) - - # Load weights from weight_loader if specified (for fine-tuning) - if config.pytorch_weight_path is not None: - logging.info(f"Loading weights from: {config.pytorch_weight_path}") - - model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") - safetensors.torch.load_model( - (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path - ) - logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") - - # Optimizer + learning rate schedule from config - warmup_steps = config.lr_schedule.warmup_steps - peak_lr = config.lr_schedule.peak_lr - decay_steps = config.lr_schedule.decay_steps - end_lr = config.lr_schedule.decay_lr - - # Create optimizer with config parameters - optim = torch.optim.AdamW( - model.parameters(), - lr=peak_lr, - betas=(config.optimizer.b1, config.optimizer.b2), - eps=config.optimizer.eps, - weight_decay=config.optimizer.weight_decay, - ) - - # Load checkpoint if resuming - global_step = 0 - if resuming: - global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) - logging.info(f"Resumed training from step {global_step}") - - def lr_schedule(step: int): - if step < warmup_steps: - # Match JAX behavior: start from peak_lr / (warmup_steps + 1) - init_lr = peak_lr / (warmup_steps + 1) - return init_lr + (peak_lr - init_lr) * step / warmup_steps - # cosine decay - progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) - cos = 0.5 * (1 + np.cos(np.pi * progress)) - return end_lr + (peak_lr - end_lr) * cos - - model.train() - start_time = time.time() - infos = [] # Collect stats over log interval - if is_main: - logging.info( - f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}" - ) - logging.info( - f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" - ) - logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") - logging.info( - f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" - ) - logging.info( - f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" - ) - logging.info("EMA is not supported for PyTorch training") - logging.info(f"Training precision: {model_cfg.dtype}") - - # Training loop - iterate until we reach num_train_steps - pbar = ( - tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) - if is_main - else None - ) - - while global_step < config.num_train_steps: - # Set epoch for distributed training - if use_ddp and hasattr(loader, "set_epoch"): - loader.set_epoch(global_step // len(loader)) - - for observation, actions in loader: - # Check if we've reached the target number of steps - if global_step >= config.num_train_steps: - break - - # The unified data loader returns (observation, actions) tuple - observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901 - actions = actions.to(torch.float32) # noqa: PLW2901 - actions = actions.to(device) # noqa: PLW2901 - - # Update LR - for pg in optim.param_groups: - pg["lr"] = lr_schedule(global_step) - - # Forward pass - losses = model(observation, actions) - # Ensure losses is a tensor and handle different 
return types - if isinstance(losses, list | tuple): - losses = torch.stack(losses) - elif not isinstance(losses, torch.Tensor): - losses = torch.tensor(losses, device=device, dtype=torch.float32) - - loss = losses.mean() - - # Backward pass - loss.backward() - - # Log memory usage after backward pass - if global_step < 5 and is_main and torch.cuda.is_available(): - log_memory_usage(device, global_step, "after_backward") - - # Gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) - - # Optimizer step - optim.step() - optim.zero_grad(set_to_none=True) - - # Clear gradients more aggressively - for param in model.parameters(): - if param.grad is not None: - param.grad.detach_() - param.grad = None - - # Collect stats - if is_main: - infos.append( - { - "loss": loss.item(), - "learning_rate": optim.param_groups[0]["lr"], - "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, - } - ) - - if is_main and (global_step % config.log_interval == 0): - elapsed = time.time() - start_time - - # Average stats over log interval - avg_loss = sum(info["loss"] for info in infos) / len(infos) - avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) - - avg_grad_norm = None - if any("grad_norm" in info for info in infos): - vals = [ - info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None - ] - if len(vals) > 0: - avg_grad_norm = sum(vals) / len(vals) - logging.info( - f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" - if avg_grad_norm is not None - else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s" - ) - - # Log to wandb - if config.wandb_enabled and len(infos) > 0: - log_payload = { - "loss": avg_loss, - "learning_rate": avg_lr, - "step": global_step, - "time_per_step": elapsed / config.log_interval, - } - if avg_grad_norm is not None: - log_payload["grad_norm"] = avg_grad_norm - wandb.log(log_payload, step=global_step) - - start_time = time.time() - infos = [] # Reset stats collection - - global_step += 1 - # Save checkpoint using the new mechanism - save_checkpoint(model, optim, global_step, config, is_main, data_config) - - # Update progress bar - if pbar is not None: - pbar.update(1) - pbar.set_postfix( - {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step} - ) - - # Close progress bar - if pbar is not None: - pbar.close() - - # Finish wandb run - if is_main and config.wandb_enabled: - wandb.finish() - - cleanup_ddp() - - -def main(): - init_logging() - config = _config.cli() - train_loop(config) - - -if __name__ == "__main__": - main() diff --git a/capvector-pi05/scripts/train_regular_loss_pytorch.py b/capvector-pi05/scripts/train_regular_loss_pytorch.py deleted file mode 100644 index 2a688cfc72b950f46073b24177fbc4b6b13246f6..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/train_regular_loss_pytorch.py +++ /dev/null @@ -1,754 +0,0 @@ -""" -PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. -This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs -entirely in PyTorch using the `PI0Pytorch` model and your existing config/data -pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. 
- -Usage -Single GPU: - python scripts/train_pytorch.py --exp_name --save_interval - Example: - python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test - python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint -Multi-GPU (single node): - torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name - Example: - torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test - torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume -Multi-Node Training: - torchrun \ - --nnodes= --nproc_per_node= --node_rank= \ - --master_addr= --master_port= \ - scripts/train_pytorch.py --exp_name= --save_interval - -""" - -import dataclasses -import gc -import logging -import os -import platform -from pathlib import Path -import shutil -import time - -import jax -import numpy as np -import safetensors.torch -import torch -import torch.distributed as dist -import torch.nn.parallel -import tqdm -import wandb - -import openpi.models.pi0_config -import openpi.models_pytorch.pi0_pytorch -import openpi.shared.normalize as _normalize -import openpi.training.config as _config -import openpi.training.data_loader as _data - - -def init_logging(): - level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} - - class CustomFormatter(logging.Formatter): - def format(self, record): - record.levelname = level_mapping.get(record.levelname, record.levelname) - return super().format(record) - - formatter = CustomFormatter( - fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", - datefmt="%H:%M:%S", - ) - logger = logging.getLogger() - logger.setLevel(logging.INFO) - if not logger.handlers: - ch = logging.StreamHandler() - ch.setFormatter(formatter) - logger.addHandler(ch) - else: - logger.handlers[0].setFormatter(formatter) - - -def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): - """Initialize wandb logging.""" - if not enabled: - wandb.init(mode="disabled") - return - - ckpt_dir = config.checkpoint_dir - if not ckpt_dir.exists(): - raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") - - if resuming: - run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() - wandb.init(id=run_id, resume="must", project=config.project_name) - else: - wandb.init( - name=config.name, - config=dataclasses.asdict(config), - project=config.project_name, - id="-".join([config.name, config.exp_name]), - ) - (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) - - -def setup_ddp(): - world_size = int(os.environ.get("WORLD_SIZE", "1")) - use_ddp = world_size > 1 - if use_ddp and not torch.distributed.is_initialized(): - backend = "nccl" if torch.cuda.is_available() else "gloo" - torch.distributed.init_process_group(backend=backend, init_method="env://") - - # Set up debugging environment variables for DDP issues - if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: - os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" - - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) - device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") - if torch.cuda.is_available(): - torch.cuda.set_device(device) - return use_ddp, local_rank, device - - -def cleanup_ddp(): - if torch.distributed.is_initialized(): - torch.distributed.barrier() - 
torch.distributed.destroy_process_group() - - -def set_seed(seed: int, local_rank: int): - torch.manual_seed(seed + local_rank) - np.random.seed(seed + local_rank) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed + local_rank) - - -def build_datasets(config: _config.TrainConfig): - # Use the unified data loader with PyTorch framework - data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) - return data_loader, data_loader.data_config() - - -def get_model_state_dict(model): - """Get state dict from model, handling DDP wrapper.""" - return ( - model.module.state_dict() - if isinstance(model, torch.nn.parallel.DistributedDataParallel) - else model.state_dict() - ) - - -def get_model_parameters(model): - """Get parameters from model, handling DDP wrapper.""" - return ( - model.module.parameters() - if isinstance(model, torch.nn.parallel.DistributedDataParallel) - else model.parameters() - ) - - -def load_regular_vector_dict(path: str | Path) -> dict[str, torch.Tensor]: - """Load the regularization vectors, which are used for delta-based regularization.""" - tensor_path = Path(path) - suffix = tensor_path.suffix.lower() - - if suffix in {".pt", ".pth"}: - tensors = torch.load(tensor_path, map_location="cpu", weights_only=False, mmap=True) - elif suffix == ".safetensors": - tensors = safetensors.torch.load_file(str(tensor_path), device="cpu") - else: - raise ValueError(f"Unsupported tensor file format: {tensor_path}") - - return tensors["state_dict"] - - -def prepare_regularization_context( - model, - config: _config.TrainConfig, -) -> dict | None: - """Load regularization tensors and build the runtime context for delta-based regularization.""" - - # Don't use regularization optionally - if not config.regularization_vector_path or config.regularization_coeff == 0.0: - return None - - # Get the regularization vectors as reference directions - if config.resume: - raise ValueError( - "Delta-based regularization with --resume is not supported in this PyTorch trainer. " - "This run now keeps the anchor only in memory at startup." 
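`load_regular_vector_dict` above reads the reference directions from either a `.pt`/`.pth` file or a `.safetensors` file and then returns the `state_dict` entry. A toy round trip through the `.pt` path (the file name and parameter names are illustrative):

```python
import torch

# Write a toy regularization-vector file in the {"state_dict": {...}} layout
# expected by load_regular_vector_dict above.
payload = {"state_dict": {"layer.weight": torch.randn(4, 4), "layer.bias": torch.randn(4)}}
torch.save(payload, "capability_vector.pt")

loaded = torch.load("capability_vector.pt", map_location="cpu", weights_only=False)
vectors = loaded["state_dict"]
print({name: tuple(t.shape) for name, t in vectors.items()})  # {'layer.weight': (4, 4), 'layer.bias': (4,)}
```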
- ) - vector_path = Path(config.regularization_vector_path).expanduser() - if not vector_path.exists(): - raise FileNotFoundError(f"Regularization vector file does not exist: {vector_path}") - regularization_vectors = load_regular_vector_dict(vector_path) - - # Get the model's trainable parameters to be regularized and the corresponding freezing anchors at startup - model_module = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - - trainable_entries = [] - missing_vectors = 0 - shape_mismatches = 0 - trainable_param_names = set() - - for name, param in model_module.named_parameters(): - if not param.requires_grad: - continue - trainable_param_names.add(name) - regularization_vector = regularization_vectors.get(name) - if regularization_vector is None: - missing_vectors += 1 - continue - anchor_param = param.detach().clone().contiguous() - if regularization_vector.shape != param.shape or anchor_param.shape != param.shape: - shape_mismatches += 1 - continue - trainable_entries.append( - { - "name": name, - "param": param, - "anchor": anchor_param, - "vector": regularization_vector.to(device=param.device, dtype=param.dtype).contiguous(), - } - ) - - logging.info( - "Regularization coverage: matched=%d missing_vectors=%d shape_mismatches=%d", - len(trainable_entries), - missing_vectors, - shape_mismatches, - ) - - return { - "entries": trainable_entries, - "weight": config.regularization_coeff, - "vector_path": str(vector_path), - } - - -def compute_regularization_loss(regularization_context: dict | None, device: torch.device) -> torch.Tensor: - """Compute the delta-based regularization loss for the current model parameters.""" - reg_loss = torch.zeros((), device=device, dtype=torch.float32) - - if not regularization_context: - return reg_loss - - for entry in regularization_context["entries"]: - param = entry["param"] - anchor = entry["anchor"] - vector = entry["vector"] - - delta = (param - anchor).reshape(-1).float() - direction = vector.reshape(-1).float() - reg_loss = reg_loss + torch.abs(torch.dot(delta, direction)) - - return reg_loss * regularization_context["weight"] - - -def save_checkpoint(model, optimizer, global_step, config, is_main, data_config): - """Save a checkpoint with model state, optimizer state, and metadata.""" - if not is_main: - return - - # Only save if it's time to save or if it's the final step - if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: - # Create temporary directory for atomic checkpoint saving - final_ckpt_dir = config.checkpoint_dir / f"{global_step}" - tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}" - - # Remove any existing temp directory and create new one - if tmp_ckpt_dir.exists(): - shutil.rmtree(tmp_ckpt_dir) - tmp_ckpt_dir.mkdir(parents=True, exist_ok=True) - - # Save model state using safetensors (handle shared tensors) - model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors") - - # Save optimizer state using PyTorch format - torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt") - - # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) - metadata = { - "global_step": global_step, - "config": dataclasses.asdict(config), - "timestamp": time.time(), - } - torch.save(metadata, tmp_ckpt_dir / "metadata.pt") - - # save norm stats - norm_stats = 
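`compute_regularization_loss` above penalizes the absolute projection of each parameter's update onto its reference direction: for every matched parameter it accumulates `|<theta - theta_anchor, v>|` and scales the sum by `config.regularization_coeff`. A toy sketch with a single 3-element parameter (all numbers are invented):

```python
import torch

coeff = 1e-3                                     # stands in for config.regularization_coeff
anchor = torch.tensor([1.0, 2.0, 3.0])           # parameter values snapshotted at startup
param = torch.tensor([1.5, 2.0, 2.0], requires_grad=True)  # current trainable values
vector = torch.tensor([1.0, 0.0, 1.0])           # reference (capability) direction

delta = (param - anchor).reshape(-1).float()     # update so far: [0.5, 0.0, -1.0]
reg_loss = coeff * torch.abs(torch.dot(delta, vector.reshape(-1).float()))
# |0.5*1 + 0.0*0 + (-1.0)*1| = 0.5, so reg_loss = 5e-4
reg_loss.backward()
print(reg_loss.item(), param.grad)  # 0.0005, tensor([-0.0010, 0.0000, -0.0010])
```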
data_config.norm_stats - if norm_stats is not None and data_config.asset_id is not None: - _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats) - - # Atomically move temp directory to final location - if final_ckpt_dir.exists(): - shutil.rmtree(final_ckpt_dir) - tmp_ckpt_dir.rename(final_ckpt_dir) - - logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}") - - # Log checkpoint to wandb - if config.wandb_enabled: - wandb.log({"checkpoint_step": global_step}, step=global_step) - - -def load_checkpoint(model, optimizer, checkpoint_dir, device): - """Load the latest checkpoint and return the global step.""" - checkpoint_steps = [ - int(d.name) - for d in checkpoint_dir.iterdir() - if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") - ] - - if not checkpoint_steps: - raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - - latest_step = max(checkpoint_steps) - ckpt_dir = checkpoint_dir / f"{latest_step}" - - # Clear memory before loading checkpoints - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "before_loading_checkpoint") - - try: - # Load model state with error handling - logging.info("Loading model state...") - safetensors_path = ckpt_dir / "model.safetensors" - - if safetensors_path.exists(): - model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model - safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device)) - logging.info("Loaded model state from safetensors format") - else: - raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}") - - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_model") - - # Load optimizer state with error handling - logging.info("Loading optimizer state...") - optimizer_path = ckpt_dir / "optimizer.pt" - - if optimizer_path.exists(): - optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False) - logging.info("Loaded optimizer state from pt format") - else: - raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}") - - optimizer.load_state_dict(optimizer_state_dict) - del optimizer_state_dict - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_optimizer") - - # Load metadata - logging.info("Loading metadata...") - metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) - global_step = metadata.get("global_step", latest_step) - del metadata - torch.cuda.empty_cache() - gc.collect() - log_memory_usage(device, latest_step, "after_loading_metadata") - - logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") - return global_step - - except RuntimeError as e: - if "out of memory" in str(e): - # Clear memory and provide detailed error message - torch.cuda.empty_cache() - gc.collect() - logging.error(f"Out of memory error while loading checkpoint: {e!s}") - log_memory_usage(device, latest_step, "after_oom_error") - raise RuntimeError( - "Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" - ) from e - raise - - -def get_latest_checkpoint_step(checkpoint_dir): - """Get the latest checkpoint step number from a checkpoint directory.""" - checkpoint_steps = [ - int(d.name) - for d in checkpoint_dir.iterdir() - if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") - ] - return max(checkpoint_steps) if checkpoint_steps else None - - -def log_memory_usage(device, step, phase="unknown"): - """Log detailed memory usage information.""" - if not torch.cuda.is_available(): - return - - memory_allocated = torch.cuda.memory_allocated(device) / 1e9 - memory_reserved = torch.cuda.memory_reserved(device) / 1e9 - memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) - memory_free = memory_free / 1e9 - - # Get more detailed memory info - memory_stats = torch.cuda.memory_stats(device) - max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9 - max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9 - - # Get DDP info if available - ddp_info = "" - if dist.is_initialized(): - ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" - - logging.info( - f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}" - ) - - -def train_loop(config: _config.TrainConfig): - use_ddp, local_rank, device = setup_ddp() - is_main = (not use_ddp) or (dist.get_rank() == 0) - set_seed(config.seed, local_rank) - - # Initialize checkpoint directory and wandb - resuming = False - if config.resume: - # Find checkpoint directory based on experiment name - exp_checkpoint_dir = config.checkpoint_dir - if exp_checkpoint_dir.exists(): - # Use validation to find the latest working checkpoint - latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) - if latest_step is not None: - resuming = True - logging.info( - f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" - ) - else: - raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") - else: - raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") - elif config.overwrite and config.checkpoint_dir.exists(): - shutil.rmtree(config.checkpoint_dir) - logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") - - # Create checkpoint directory with experiment name - if not resuming: - # For new runs, create experiment-specific checkpoint directory - exp_checkpoint_dir = config.checkpoint_dir - exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") - else: - # For resume, checkpoint_dir is already set to the experiment directory - logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") - - # Initialize wandb (only on main process) - if is_main: - init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) - - # Build data loader using the unified data loader - # Calculate effective batch size per GPU for DDP - # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size - world_size = torch.distributed.get_world_size() if use_ddp else 1 - effective_batch_size = config.batch_size // world_size - logging.info( - f"Using batch size per GPU: 
{effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" - ) - - # Pass the original batch size to data loader - it will handle DDP splitting internally - loader, data_config = build_datasets(config) - - # Log sample images to wandb on first batch - if is_main and config.wandb_enabled and not resuming: - # Create a separate data loader for sample batch to avoid consuming the main loader - sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) - sample_batch = next(iter(sample_data_loader)) - # Convert observation and actions to torch tensors - observation, actions = sample_batch - sample_batch = observation.to_dict() - sample_batch["actions"] = actions - - # Create sample images for wandb - images_to_log = [] - # Get batch size from the first image tensor - batch_size = next(iter(sample_batch["image"].values())).shape[0] - for i in range(min(5, batch_size)): - # Concatenate all camera views horizontally for this batch item - # Convert from NCHW to NHWC format for wandb - img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) - img_concatenated = img_concatenated.cpu().numpy() - images_to_log.append(wandb.Image(img_concatenated)) - - wandb.log({"camera_views": images_to_log}, step=0) - - # Clear sample batch from memory aggressively - del sample_batch, observation, actions, images_to_log, img_concatenated - del sample_data_loader # Also delete the sample data loader - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - logging.info("Cleared sample batch and data loader from memory") - - # Build model - if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): - # Convert dataclass to Pi0Config if needed - model_cfg = openpi.models.pi0_config.Pi0Config( - dtype=config.pytorch_training_precision, - action_dim=config.model.action_dim, - action_horizon=config.model.action_horizon, - max_token_len=config.model.max_token_len, - paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), - action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), - pi05=getattr(config.model, "pi05", False), - ) - else: - model_cfg = config.model - # Update dtype to match pytorch_training_precision - object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) - - model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) - - if hasattr(model, "gradient_checkpointing_enable"): - enable_gradient_checkpointing = True - model.gradient_checkpointing_enable() - logging.info("Enabled gradient checkpointing for memory optimization") - else: - enable_gradient_checkpointing = False - logging.info("Gradient checkpointing is not supported for this model") - - # Log initial memory usage after model creation - if is_main and torch.cuda.is_available(): - log_memory_usage(device, 0, "after_model_creation") - - # Enable memory optimizations for large-scale training - if world_size >= 8: - torch.backends.cudnn.benchmark = True - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - # Set memory allocation configuration - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" - logging.info("Enabled memory optimizations for 8+ GPU training") - - if use_ddp: - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[device.index] if device.type == "cuda" else None, - find_unused_parameters=True, # Disable for memory efficiency - 
gradient_as_bucket_view=True, # Enable for memory efficiency - static_graph=world_size >= 8, # Enable for 8+ GPUs - ) - - # Load weights from weight_loader if specified (for fine-tuning) - if config.pytorch_weight_path is not None: - logging.info(f"Loading weights from: {config.pytorch_weight_path}") - - model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") - safetensors.torch.load_model( - (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path - ) - logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") - - regularization_context = prepare_regularization_context(model, config) - - # Optimizer + learning rate schedule from config - warmup_steps = config.lr_schedule.warmup_steps - peak_lr = config.lr_schedule.peak_lr - decay_steps = config.lr_schedule.decay_steps - end_lr = config.lr_schedule.decay_lr - - # Create optimizer with config parameters - optim = torch.optim.AdamW( - model.parameters(), - lr=peak_lr, - betas=(config.optimizer.b1, config.optimizer.b2), - eps=config.optimizer.eps, - weight_decay=config.optimizer.weight_decay, - ) - - # Load checkpoint if resuming - global_step = 0 - if resuming: - global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) - logging.info(f"Resumed training from step {global_step}") - - def lr_schedule(step: int): - if step < warmup_steps: - # Match JAX behavior: start from peak_lr / (warmup_steps + 1) - init_lr = peak_lr / (warmup_steps + 1) - return init_lr + (peak_lr - init_lr) * step / warmup_steps - # cosine decay - progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) - cos = 0.5 * (1 + np.cos(np.pi * progress)) - return end_lr + (peak_lr - end_lr) * cos - - model.train() - start_time = time.time() - infos = [] # Collect stats over log interval - if is_main: - logging.info( - f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}" - ) - logging.info( - f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" - ) - logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") - logging.info( - f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" - ) - logging.info( - f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" - ) - logging.info("EMA is not supported for PyTorch training") - logging.info(f"Training precision: {model_cfg.dtype}") - if regularization_context: - logging.info( - "Delta-based regularization: enabled | weight=%.2e | vector=%s", - config.regularization_coeff, - regularization_context["vector_path"], - ) - - # Training loop - iterate until we reach num_train_steps - pbar = ( - tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) - if is_main - else None - ) - - while global_step < config.num_train_steps: - # Set epoch for distributed training - if use_ddp and hasattr(loader, "set_epoch"): - loader.set_epoch(global_step // len(loader)) - - for observation, actions in loader: - # Check if we've reached the target number of steps - if global_step >= config.num_train_steps: - break - - # The unified data loader returns (observation, actions) tuple - observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901 - actions = 
actions.to(torch.float32) # noqa: PLW2901 - actions = actions.to(device) # noqa: PLW2901 - - # Update LR - for pg in optim.param_groups: - pg["lr"] = lr_schedule(global_step) - - # Forward pass - losses = model(observation, actions) - # Ensure losses is a tensor and handle different return types - if isinstance(losses, list | tuple): - losses = torch.stack(losses) - elif not isinstance(losses, torch.Tensor): - losses = torch.tensor(losses, device=device, dtype=torch.float32) - - action_loss = losses.mean() - regularization_loss = compute_regularization_loss(regularization_context, device) - total_loss = action_loss + regularization_loss - - # Backward pass - total_loss.backward() - - # Log memory usage after backward pass - if global_step < 5 and is_main and torch.cuda.is_available(): - log_memory_usage(device, global_step, "after_backward") - - # Gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) - - # Optimizer step - optim.step() - optim.zero_grad(set_to_none=True) - - # Clear gradients more aggressively - for param in model.parameters(): - if param.grad is not None: - param.grad.detach_() - param.grad = None - - # Collect stats - if is_main: - infos.append( - { - "action_loss": action_loss.item(), - "regularization_loss": regularization_loss.item(), - "total_loss": total_loss.item(), - "learning_rate": optim.param_groups[0]["lr"], - "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, - } - ) - - if is_main and (global_step % config.log_interval == 0): - elapsed = time.time() - start_time - - # Average stats over log interval - avg_action_loss = sum(info["action_loss"] for info in infos) / len(infos) - avg_regularization_loss = sum(info["regularization_loss"] for info in infos) / len(infos) - avg_total_loss = sum(info["total_loss"] for info in infos) / len(infos) - avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) - - avg_grad_norm = None - if any("grad_norm" in info for info in infos): - vals = [ - info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None - ] - if len(vals) > 0: - avg_grad_norm = sum(vals) / len(vals) - logging.info( - f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" - if avg_grad_norm is not None - else f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s" - ) - - # Log to wandb - if config.wandb_enabled and len(infos) > 0: - log_payload = { - "action_loss": avg_action_loss, - "regularization_loss": avg_regularization_loss, - "total_loss": avg_total_loss, - "learning_rate": avg_lr, - "step": global_step, - "time_per_step": elapsed / config.log_interval, - } - if avg_grad_norm is not None: - log_payload["grad_norm"] = avg_grad_norm - wandb.log(log_payload, step=global_step) - - start_time = time.time() - infos = [] # Reset stats collection - - global_step += 1 - # Save checkpoint using the new mechanism - save_checkpoint(model, optim, global_step, config, is_main, data_config) - - # Update progress bar - if pbar is not None: - pbar.update(1) - pbar.set_postfix( - { - "action_loss": f"{action_loss.item():.4f}", - "reg_loss": f"{regularization_loss.item():.4f}", - "total_loss": f"{total_loss.item():.4f}", - "lr": 
f"{optim.param_groups[0]['lr']:.2e}", - "step": global_step, - } - ) - - # Close progress bar - if pbar is not None: - pbar.close() - - # Finish wandb run - if is_main and config.wandb_enabled: - wandb.finish() - - cleanup_ddp() - - -def main(): - init_logging() - config = _config.cli() - train_loop(config) - - -if __name__ == "__main__": - main() diff --git a/capvector-pi05/scripts/train_test.py b/capvector-pi05/scripts/train_test.py deleted file mode 100644 index 9e0a31234f68eadbe2721c18c1c98967f68280dc..0000000000000000000000000000000000000000 --- a/capvector-pi05/scripts/train_test.py +++ /dev/null @@ -1,30 +0,0 @@ -import dataclasses -import os -import pathlib - -import pytest - -os.environ["JAX_PLATFORMS"] = "cpu" - -from openpi.training import config as _config - -from . import train - - -@pytest.mark.parametrize("config_name", ["debug"]) -def test_train(tmp_path: pathlib.Path, config_name: str): - config = dataclasses.replace( - _config._CONFIGS_DICT[config_name], # noqa: SLF001 - batch_size=2, - checkpoint_base_dir=str(tmp_path / "checkpoint"), - exp_name="test", - overwrite=False, - resume=False, - num_train_steps=2, - log_interval=1, - ) - train.main(config) - - # test resuming - config = dataclasses.replace(config, resume=True, num_train_steps=4) - train.main(config) diff --git a/capvector-pi05/src/openpi/__init__.py b/capvector-pi05/src/openpi/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/src/openpi/conftest.py b/capvector-pi05/src/openpi/conftest.py deleted file mode 100644 index dfc58102eae19451f63992024bb9a856c78b46a5..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/conftest.py +++ /dev/null @@ -1,17 +0,0 @@ -import os - -import pynvml -import pytest - - -def set_jax_cpu_backend_if_no_gpu() -> None: - try: - pynvml.nvmlInit() - pynvml.nvmlShutdown() - except pynvml.NVMLError: - # No GPU found. - os.environ["JAX_PLATFORMS"] = "cpu" - - -def pytest_configure(config: pytest.Config) -> None: - set_jax_cpu_backend_if_no_gpu() diff --git a/capvector-pi05/src/openpi/models/__init__.py b/capvector-pi05/src/openpi/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/src/openpi/models/gemma.py b/capvector-pi05/src/openpi/models/gemma.py deleted file mode 100644 index 128a286cae227a461c88dde0e3f7f3b7bb21bce6..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/gemma.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright 2024 Big Vision Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Gemma adaptation for Pi, taken from big_vision. 
- -We follow this einsum axis naming convention: - B: batch - T: query length - S: k/v length - N: num query heads - K: num k/v heads - G: num query heads per k/v head - H: head dim - D: d_model ("features") -""" - -from collections.abc import Sequence -import dataclasses -from typing import Literal, TypeAlias - -import einops -import flax.linen as nn -import jax -import jax.numpy as jnp - -import openpi.models.lora as lora -import openpi.shared.array_typing as at -import openpi.training.sharding as sharding - -PALIGEMMA_VOCAB_SIZE = 257_152 - - -@dataclasses.dataclass -class Config: - width: int - depth: int - mlp_dim: int - num_heads: int - num_kv_heads: int - head_dim: int - lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict) - - -Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"] - - -def get_config(variant: Variant) -> Config: - """Returns config for specified gemma variant.""" - if variant == "dummy": - return Config( - width=64, - depth=4, - mlp_dim=128, - num_heads=8, - num_kv_heads=1, - head_dim=16, - ) - if variant == "gemma_300m": - # 311M params - return Config( - width=1024, - depth=18, - mlp_dim=4096, - num_heads=8, - num_kv_heads=1, - head_dim=256, - ) - if variant == "gemma_2b": - return Config( - width=2048, - depth=18, - mlp_dim=16_384, - num_heads=8, - num_kv_heads=1, - head_dim=256, - ) - if variant == "gemma_2b_lora": - return Config( - width=2048, - depth=18, - mlp_dim=16_384, - num_heads=8, - num_kv_heads=1, - head_dim=256, - lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)}, - ) - if variant == "gemma_300m_lora": - # 311M params - return Config( - width=1024, - depth=18, - mlp_dim=4096, - num_heads=8, - num_kv_heads=1, - head_dim=256, - lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)}, - ) - raise ValueError(f"Unknown variant: {variant}") - - -@at.typecheck -class RMSNorm(nn.Module): - @nn.compact - def __call__(self, x, cond): - dtype = x.dtype # original dtype, could be half-precision - var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32 - normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32 - if cond is None: - # regular RMSNorm - scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) - normed_inputs = normed_inputs * ( - 1 + scale - ) # scale by learned parameter in float32 (matches Flax implementation) - return normed_inputs.astype(dtype), None # return in original dtype - - # adaptive RMSNorm - modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond) - scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1) - normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32 - return normed_inputs.astype(dtype), gate - - -@at.typecheck -class Embedder(nn.Module): - """Embedder module.""" - - vocab_size: int - embed_dim: int - - def setup(self): - self.input_embedding_table = self.param( - "input_embedding", - nn.initializers.normal(), - (self.vocab_size, self.embed_dim), - ) - - def encode(self, x): - x = self.input_embedding_table[(x,)] - x *= jnp.sqrt(self.embed_dim).astype(x.dtype) - return x - - def decode(self, x): - return jnp.dot(x, self.input_embedding_table.T) - - -@at.typecheck -class Attention(nn.Module): - """Attention module.""" - - configs: Sequence[Config] - - @nn.compact - def 
__call__(self, xs, positions, attn_mask, kv_cache): - # all experts must share the same head dim, num heads, and num kv heads for self-attention to work - assert all(config.head_dim == self.configs[0].head_dim for config in self.configs) - assert all(config.num_heads == self.configs[0].num_heads for config in self.configs) - assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs) - - dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision - - qkvs = [] - for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): - if x is None: - continue - if config.num_kv_heads == config.num_heads: - qkv_einsum = lora.Einsum( - shape=(3, config.num_heads, config.width, config.head_dim), - name=_name("qkv_einsum", i), - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), - lora_config=config.lora_configs.get("attn"), - ) - qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x)) - else: - q_einsum = lora.Einsum( - shape=(config.num_heads, config.width, config.head_dim), - name=_name("q_einsum", i), - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), - lora_config=config.lora_configs.get("attn"), - ) - q = q_einsum("BTD,NDH->BTNH", x) - kv_einsum = lora.Einsum( - shape=(2, config.num_kv_heads, config.width, config.head_dim), - name=_name("kv_einsum", i), - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), - lora_config=config.lora_configs.get("attn"), - ) - k, v = kv_einsum("BSD,2KDH->2BSKH", x) - qkvs.append((q, k, v)) - - q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True)) - - q = _apply_rope(q, positions=positions) - q *= self.configs[0].head_dim ** -0.5 - - k = _apply_rope(k, positions=positions) - - # should still be half-precision here (if input was half-precision) - assert q.dtype == k.dtype == v.dtype == dtype - - if kv_cache is not None: - cache_k, cache_v = kv_cache - k = jnp.concatenate([cache_k, k], axis=1) - v = jnp.concatenate([cache_v, v], axis=1) - - q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads) - logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32) - - if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): - raise ValueError( - f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}" - ) - - # big_neg = jnp.finfo(logits.dtype).min - big_neg = -2.3819763e38 # See gemma/modules.py - masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) - - probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype) - - encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) - encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") - - out = [] - start = 0 - for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): - if x is not None: - end = start + x.shape[1] - out_einsum = lora.Einsum( - shape=(config.num_heads, config.head_dim, config.width), - name=_name("attn_vec_einsum", i), - init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1), - lora_config=config.lora_configs.get("attn"), - ) - out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end])) - start = end - else: - out.append(None) - - return out, (k, v) - - -@at.typecheck -class FeedForward(nn.Module): - """Feed forward module.""" - - features: int - hidden_dim: int - - @nn.compact - def __call__(self, x): - dtype = x.dtype # original dtype, could be half-precision - w_gating = self.param( - 
"gating_einsum", - nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), - (2, self.features, self.hidden_dim), - ).astype(dtype) - ff_gate = jnp.dot(x, w_gating[0]) - gate_value = nn.gelu(ff_gate) - - ff1 = jnp.dot(x, w_gating[1]) - activations = gate_value * ff1 - - w_linear = self.param( - "linear", - nn.initializers.lecun_normal(in_axis=-2, out_axis=-1), - (self.hidden_dim, self.features), - ).astype(dtype) - outputs = jnp.dot(activations, w_linear) - assert outputs.dtype == dtype - return outputs - - -@at.typecheck -class Block(nn.Module): - """Transformer block.""" - - configs: tuple[Config, ...] - - dropout: float = 0.0 - dropout_bdims: tuple[int, ...] = () - - @nn.compact - def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002 - xs = sharding.activation_sharding_constraint(xs) - drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x - - attn = Attention(configs=self.configs, name="attn") - - pre_attn = [] - gates = [] - for i, x in enumerate(xs): - if x is not None: - x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901 - pre_attn.append(x) - gates.append(gate if x is not None else None) - - pre_attn = sharding.activation_sharding_constraint(pre_attn) - post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache) - post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn) - post_attn = sharding.activation_sharding_constraint(post_attn) - xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)] - xs = sharding.activation_sharding_constraint(xs) - - out = [] - gates = [] - for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): - if x is not None: - x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901 - x = lora.FeedForward( # noqa: PLW2901 - features=config.width, - hidden_dim=config.mlp_dim, - name=_name("mlp", i), - lora_config=config.lora_configs.get("ffn"), - )(x) - out.append(x) - gates.append(gate if x is not None else None) - - out = sharding.activation_sharding_constraint(out) - out = jax.tree.map(lambda x: drop(x, deterministic), out) - xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)] - xs = sharding.activation_sharding_constraint(xs) - - return xs, kv_cache - - -KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]] - - -@at.typecheck -class Module(nn.Module): - """Transformer model, supporting a mixture of different weights for different tokens.""" - - configs: Sequence[Config] # list of configs, one for each expert - embed_dtype: str - - dropout: float = 0.0 - dropout_bdims: tuple[int, ...] = () # Every float is dropped independently. 
- adarms: bool = False - - def setup(self): - # all experts must have the same depth - assert all(config.depth == self.configs[0].depth for config in self.configs) - - self.embedder = Embedder( - vocab_size=PALIGEMMA_VOCAB_SIZE, - embed_dim=self.configs[0].width, # embedder for first expert only - name="embedder", - ) - block_cls = nn.remat( - Block, - prevent_cse=False, - static_argnums=(5,), # 0=self, 6=deterministic - policy=jax.checkpoint_policies.nothing_saveable, - ) - self.layers = nn.scan( - block_cls, - variable_axes={"params": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=( - 0, - nn.broadcast, - nn.broadcast, - nn.broadcast, - nn.broadcast, - ), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic - length=self.configs[0].depth, - )( - configs=self.configs, - dropout=self.dropout, - dropout_bdims=self.dropout_bdims, - ) - self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))] - - @at.typecheck - def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]: - return self.embedder.encode(tokens).astype(self.embed_dtype) - - @at.typecheck - def __call__( - self, - # list of token arrays, one for each expert, or None if that expert should not be run - embedded: Sequence[at.Float[at.Array, "b _t _d"] | None], - positions: at.Int[at.Array, "b t"], - mask: at.Bool[at.Array, "b t s"], - adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None, - *, - kv_cache: KVCache | None = None, - deterministic: bool = True, - ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]: - embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded) - mask = jnp.asarray(mask)[:, None, :, :] - if adarms_cond is None: - adarms_cond = [None] * len(self.configs) - - embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic) - - assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None) - - return [ - f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True) - ], kv_cache - - def init(self, use_adarms: Sequence[bool]): - """Convenience method for initializing all parameters, necessary due to the quirks of linen.""" - self.embed(jnp.zeros((1, 1), dtype=jnp.int32)) - self( - [jnp.zeros((1, 1, c.width)) for c in self.configs], - jnp.zeros((1, len(self.configs)), dtype=jnp.int32), - jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool), - adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)], - ) - - -def _apply_rope(x, *, positions, max_wavelength=10_000): - """Applies RoPE positions [B, L] to x [B, L, H, D].""" - freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) - timescale = max_wavelength**freq_exponents - radians = positions[..., None] / timescale[None, None, :] - radians = radians[..., None, :] - assert radians.dtype == jnp.float32 - # radians.shape = [...,L,1,d=D/2] - sin, cos = jnp.sin(radians), jnp.cos(radians) - x1, x2 = jnp.split(x, 2, axis=-1) - res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) - assert res.dtype == jnp.float32 - # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache - # dtype when in inference mode (but not in training mode). I don't think any of this was intentional. 
Based on the - # original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16 - # here. - return res.astype(x.dtype) - - -def _name(name, i): - # we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they - # can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g., - # "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma, - # and the action expert. - if i == 0: - return name - return f"{name}_{i}" - - -def _gated_residual(x, y, gate): - assert (x is None) == (y is None) - if x is None: - return None - if gate is None: - return x + y - return x + y * gate diff --git a/capvector-pi05/src/openpi/models/gemma_fast.py b/capvector-pi05/src/openpi/models/gemma_fast.py deleted file mode 100644 index 0ba787601e143338a54241dab96c0fb6a311de89..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/gemma_fast.py +++ /dev/null @@ -1,437 +0,0 @@ -# Copyright 2024 Big Vision Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility) -Used for FAST autoregressive policies. -""" - -import dataclasses -from typing import Literal, TypeAlias - -import einops -import flax.linen as nn -import jax -import jax.numpy as jnp -import ml_collections - -import openpi.models.lora as lora -import openpi.shared.array_typing as at - -Variant = Literal["gemma_2b", "gemma_2b_lora"] - - -def get_config(variant): - """Returns config for specified gemma variant.""" - if variant == "gemma_2b": - return ml_collections.ConfigDict( - { - "variant": variant, - "width": 2048, - "depth": 18, - "mlp_dim": 16_384, - "num_heads": 8, - "num_kv_heads": 1, - "head_dim": 256, - "norm_eps": 1e-6, - "vocab_size": 257_152, - "scan": True, - "remat_policy": "nothing_saveable", - } - ) - if variant == "gemma_2b_lora": - return ml_collections.ConfigDict( - { - "variant": variant, - "width": 2048, - "depth": 18, - "mlp_dim": 16_384, - "num_heads": 8, - "num_kv_heads": 1, - "head_dim": 256, - "norm_eps": 1e-6, - "vocab_size": 257_152, - "scan": True, - "remat_policy": "nothing_saveable", - "lora_configs": { - "attn": lora.LoRAConfig(rank=16, alpha=16.0), - "ffn": lora.LoRAConfig(rank=16, alpha=16.0), - }, - } - ) - raise ValueError(f"Unknown variant: {variant}") - - -@at.typecheck -class Einsum(nn.Module): - shape: tuple[int, ...] 
- - @nn.compact - def __call__(self, eqn, x): - dtype = x.dtype # original dtype, could be half-precision - w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype) - return jnp.einsum(eqn, x, w) - - -@at.typecheck -class RMSNorm(nn.Module): - @nn.compact - def __call__(self, x): - dtype = x.dtype # original dtype, could be half-precision - scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) - var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32 - normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32 - normed_inputs = normed_inputs * ( - 1 + scale - ) # scale by learned parameter in float32 (matches Flax implementation) - return normed_inputs.astype(dtype) # return in original dtype - - -@at.typecheck -class Embedder(nn.Module): - """Embedder module.""" - - vocab_size: int - embed_dim: int - - def setup(self): - self.input_embedding_table = self.param( - "input_embedding", - nn.initializers.zeros_init(), - (self.vocab_size, self.embed_dim), - ) - - def encode(self, x): - x = self.input_embedding_table[(x,)] - x *= jnp.sqrt(self.embed_dim).astype(x.dtype) - return x - - def decode(self, x): - return jnp.dot(x, self.input_embedding_table.T) - - -@at.typecheck -class Attention(nn.Module): - """Attention module.""" - - num_heads: int - num_kv_heads: int - features: int - head_dim: int - - cache_dtype: str | None = None - - lora_config: lora.LoRAConfig | None = None - - def setup(self): - if self.num_kv_heads == self.num_heads: - self.qkv_einsum = lora.Einsum( - shape=(3, self.num_heads, self.features, self.head_dim), - name="qkv_einsum", - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), - lora_config=self.lora_config, - ) - else: - self.q_einsum = lora.Einsum( - shape=(self.num_heads, self.features, self.head_dim), - name="q_einsum", - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), - lora_config=self.lora_config, - ) - self.kv_einsum = lora.Einsum( - shape=(2, self.num_kv_heads, self.features, self.head_dim), - name="kv_einsum", - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), - lora_config=self.lora_config, - ) - self.attn_vec_einsum = lora.Einsum( - shape=(self.num_heads, self.head_dim, self.features), - name="attn_vec_einsum", - init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), - lora_config=self.lora_config, - ) - - def _init_cache(self, k, v, cache_size): - """Initialize KV cache""" - prefill_len = k.shape[1] - pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0)) - cache_dtype = self.cache_dtype or k.dtype - k_cache = jnp.pad(k.astype(cache_dtype), pad_width) - v_cache = jnp.pad(v.astype(cache_dtype), pad_width) - idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len - return idx, k_cache, v_cache - - def _update_cache(self, k, v, idx, k_cache, v_cache): - """Update KV cache with new values""" - assert k.shape[1] == 1, "Only support kv-cache updates of length 1" - indices = (0, idx[0], 0, 0) - cache_dtype = self.cache_dtype or k.dtype - k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices) - v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices) - idx_new = idx + 1 - return idx_new, k_new, v_new - - @nn.compact - def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002 - dtype = x.dtype # original dtype, could 
be half-precision - if self.num_kv_heads == self.num_heads: - q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x) - else: - q = self.q_einsum("BTD,NDH->BTNH", x) - k, v = self.kv_einsum("BSD,2KDH->2BSKH", x) - - q = _apply_rope(q, positions=positions) # promotes to float32 - q *= self.head_dim**-0.5 - - k = _apply_rope(k, positions=positions) # promotes to float32 - - if kv_cache is None: - idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1]) - else: - idx, k_cache, v_cache = kv_cache - idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache) - - k, v = k_cache, v_cache - kv_cache = (idx, k_cache, v_cache) - - q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads) - logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32) - - if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): - raise ValueError( - f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}" - ) - - # big_neg = jnp.finfo(logits.dtype).min - big_neg = -2.3819763e38 # See gemma/modules.py - masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) - - probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype) - - encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) - encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") - return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache - - -@at.typecheck -class Block(nn.Module): - """Transformer block.""" - - num_heads: int - num_kv_heads: int - embed_dim: int - head_dim: int - hidden_dim: int - - dropout: float = 0.0 - dropout_bdims: tuple[int, ...] = () - cache_dtype: str | None = None - lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict) - - def setup(self): - self.pre_attention_norm = RMSNorm() - self.attn = Attention( - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - features=self.embed_dim, - head_dim=self.head_dim, - cache_dtype=self.cache_dtype, - lora_config=self.lora_configs.get("attn"), - ) - self.pre_ffw_norm = RMSNorm() - self.mlp = lora.FeedForward( - features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn") - ) - if self.dropout: - self.drop = nn.Dropout(self.dropout, self.dropout_bdims) - else: - self.drop = lambda x, _: x - - def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002 - x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) - inputs_normalized = self.pre_attention_norm(x) - attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic) - attn_output = self.drop(attn_output, deterministic) - attn_output += x - residual = attn_output - attn_output = self.pre_ffw_norm(attn_output) - outputs = self.mlp(attn_output) - outputs = self.drop(outputs, deterministic) - outputs = residual + outputs - return outputs, kv_cache - - -KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]] - - -@at.typecheck -class Module(nn.Module): - """gemma model.""" - - variant: str - - width: int - depth: int - mlp_dim: int - num_heads: int - num_kv_heads: int - head_dim: int - norm_eps: float - vocab_size: int - embed_dtype: str - - dropout: float = 0.0 - dropout_bdims: tuple[int, ...] = () # Every float is dropped independently. 
- cache_dtype: str | None = None - - scan: bool = False - remat_policy: str = "none" - lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict) - - @nn.compact - def __call__( - self, - tokens=None, - embedded_prefix=None, - embed_only=False, # noqa: FBT002 - pre_logits=None, - positions=None, - mask=None, - decode=False, # noqa: FBT002 - kv_cache=None, - deterministic=True, # noqa: FBT002 - return_prelogits=False, # noqa: FBT002 - ): - """Embed only, or complete forward pass. - - Args: - tokens: Embedded, then and appended to `embedded_prefix`. Can be None. - embedded_prefix: Optional prefix that is already embedded. - embed_only: Whether to compute embeddings only. - pre_logits: If present computes logits from pre_logits and returns. - positions: Optional `[B, T]` allows to specify the absolute position of - the tokens. - mask: Optional attention mask `[B, T, S]`. - decode: Whether to use kv-cache. Caller must pass masks and positions. - deterministic: Forwarded to all dropout layers. - return_prelogits: Whether to return the pre-logits. - - Returns: - If `embed_only=False`, then `(logits, out)` will be returned. - If `embed_only=True`, then the embeddings will be returned. - If `return_prelogits=True`, then the pre-logits will be returned. - """ - out = {} - - embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder") - - if pre_logits is not None: - x = out["pre_logits"] = pre_logits - logits = out["logits"] = embedder.decode(x) - return logits, out - - x = [] - if embedded_prefix is not None: - x.append(embedded_prefix) - if tokens is not None: - x.append(embedder.encode(tokens)) - - x = jnp.concatenate(x, axis=-2) - x = x.astype(self.embed_dtype) - batch_size, seq_len, width = x.shape - - if embed_only: - return x - - if decode: - assert positions is not None and mask is not None, ( # noqa: PT018 - "Must explicitly pass positions and mask for decoding." - ) - - if positions is None: - positions = jnp.arange(seq_len).astype(jnp.int32)[None, :] - assert positions.shape[1] == x.shape[1], (positions.shape, x.shape) - - if mask is None: - mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len])) - if mask.ndim == 3: - mask = mask[:, None, :, :] - cache_size = max(seq_len, mask.shape[-1]) - assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape - - if self.remat_policy == "none": - block_cls = Block - else: - block_cls = nn.remat( - Block, - prevent_cse=not self.scan, - static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic - policy=getattr(jax.checkpoint_policies, self.remat_policy), - ) - - block_kw = { - "num_heads": self.num_heads, - "head_dim": self.head_dim, - "num_kv_heads": self.num_kv_heads, - "embed_dim": width, - "hidden_dim": self.mlp_dim, - "dropout": self.dropout, - "dropout_bdims": self.dropout_bdims, - "cache_dtype": self.cache_dtype, - "lora_configs": self.lora_configs, - } - layers = self.scope.push("layers") - blocks = [ - nn.scan( - block_cls, - variable_axes={"params": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask - length=self.depth, - )(parent=layers, **block_kw) - ] - for block in blocks: - x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic) - - assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check. 
- out["encoded"] = x - - x = RMSNorm(name="final_norm")(x) - out["pre_logits"] = x - if return_prelogits: - return x, kv_cache, out - - x = embedder.decode(x) - out["logits"] = x - - return x, kv_cache, out - - def init(self): - """Convenience method for initializing all parameters, necessary due to the quirks of linen.""" - self(jnp.zeros((1, 1), dtype=jnp.int32)) - - -def _apply_rope(x, *, positions, max_wavelength=10_000): - """Applies RoPE positions [B, L] to x [B, L, H, D].""" - freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) - timescale = max_wavelength**freq_exponents - radians = positions[..., None] / timescale[None, None, :] - radians = radians[..., None, :] - assert radians.dtype == jnp.float32 - # radians.shape = [...,L,1,d=D/2] - sin, cos = jnp.sin(radians), jnp.cos(radians) - x1, x2 = jnp.split(x, 2, axis=-1) - res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) - assert res.dtype == jnp.float32 - return res diff --git a/capvector-pi05/src/openpi/models/lora.py b/capvector-pi05/src/openpi/models/lora.py deleted file mode 100644 index 0524f2e10023837634e3855d87f35a53736ccd07..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/lora.py +++ /dev/null @@ -1,148 +0,0 @@ -import math -import re - -import flax.linen as nn -import flax.struct as struct -import jax.numpy as jnp - -import openpi.shared.array_typing as at - - -@struct.dataclass -class LoRAConfig: - """Configuration for LoRA.""" - - # LoRA rank. - rank: int - # LoRA scaling factor. - alpha: float = 1.0 - # Initialization function for LoRA parameters. - init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01) - # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732 - rslora: bool = False - # Axes in the weight to apply LoRA to. Should typically be the last two axes. - axes: tuple[int, int] = (-2, -1) - # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation. - label: str = "L" - - @property - def scaling_value(self) -> float: - return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank - - -class Einsum(nn.Module): - """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum.""" - - # Shape of the weight. - shape: tuple[int, ...] - # Initialization function for the weight. - init_fn: nn.initializers.Initializer = nn.initializers.zeros - # If not None, apply LoRA to the weight. - lora_config: LoRAConfig | None = None - - def setup(self): - self.w = self.param("w", self.init_fn, self.shape) - - if config := self.lora_config: - # Setup LoRA parameters. 
- shape_a, shape_b = list(self.shape), list(self.shape) - shape_a[config.axes[1]] = config.rank - shape_b[config.axes[0]] = config.rank - self.w_a = self.param("lora_a", config.init_fn, shape_a) - self.w_b = self.param("lora_b", config.init_fn, shape_b) - - @nn.compact - def __call__(self, eqn: str, x): - dtype = x.dtype # original dtype, could be half-precision - result = jnp.einsum(eqn, x, self.w.astype(dtype)) - - if config := self.lora_config: - eqn_a, eqn_b = self._make_lora_eqns(eqn) - lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype)) - lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype)) - result = result + lora * config.scaling_value - - return result - - def _make_lora_eqns(self, eqn: str) -> tuple[str, str]: - if "L" in eqn: - raise ValueError(f"L already in eqn: {eqn}") - if not (m := re.match("(.*),(.*)->(.*)", eqn)): - raise ValueError(f"Unsupported einsum eqn: {eqn}") - lhs, rhs, out = m.groups() - - assert self.lora_config is not None - a_label, b_label = (rhs[x] for x in self.lora_config.axes) - label = self.lora_config.label - - a_rhs = rhs.replace(b_label, label) - a_out = out.replace(b_label, label) - eqn_a = f"{lhs},{a_rhs}->{a_out}" - - b_rhs = rhs.replace(a_label, label) - eqn_b = f"{a_out},{b_rhs}->{out}" - - return eqn_a, eqn_b - - -class FeedForward(nn.Module): - """Feed forward module.""" - - features: int - hidden_dim: int - # If not None, apply LoRA to the weight. - lora_config: LoRAConfig | None = None - - def setup(self): - self.w_gating = self.param( - "gating_einsum", - nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), - (2, self.features, self.hidden_dim), - ) - self.w_linear = self.param( - "linear", - nn.initializers.lecun_normal(in_axis=-2, out_axis=-1), - (self.hidden_dim, self.features), - ) - self.w_gating_lora = None - self.w_linear_lora = None - if self.lora_config: - # Setup LoRA parameters. - # TODO: follow up with a simplified init_fn api. 
- self.w_gating_lora = ( - self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)), - self.param( - "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim) - ), - ) - self.w_linear_lora = ( - self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)), - self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)), - ) - - @nn.compact - def __call__(self, x): - dtype = x.dtype # original dtype, could be half-precision - ff_gate = self._dot( - x, - self.w_gating[0], - None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]), - ) - gate_value = nn.gelu(ff_gate) - - ff1 = self._dot( - x, - self.w_gating[1], - None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]), - ) - activations = gate_value * ff1 - - outputs = self._dot(activations, self.w_linear, self.w_linear_lora) - assert outputs.dtype == dtype - return outputs - - def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array: - base = jnp.dot(x, w.astype(x.dtype)) - if lora_weights is None: - return base - return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype)) diff --git a/capvector-pi05/src/openpi/models/lora_test.py b/capvector-pi05/src/openpi/models/lora_test.py deleted file mode 100644 index d303c025000565b8c3002cb9a42ac8246a2a4936..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/lora_test.py +++ /dev/null @@ -1,94 +0,0 @@ -import flax.linen as nn -import jax -import jax.numpy as jnp - -import openpi.models.lora as lora - - -def test_lora_einsum_params_shape(): - shape = (3, 8, 32, 4) # (3KDH) - einsum = lora.Einsum(shape) - lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2)) - lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2))) - - key = jax.random.key(0) - x = jax.random.normal(key, (8, 64, 32)) # (BSD) - eqn = "BSD,3KDH->3BSKH" - - # Ensure that lora parameters are not initialized when LoRA is not used. - params = einsum.init(key, eqn, x) - assert "lora_a" not in params["params"] - assert "lora_b" not in params["params"] - - # Check that default axes work. - params_lora0 = lora0.init(key, eqn, x) - assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2) - assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4) - - # Check that user provided axes work. - params_lora1 = lora1.init(key, eqn, x) - assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4) - assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4) - - -def test_lora_einsum_same_output(): - shape = (3, 8, 32, 4) # (3KDH) - einsum = lora.Einsum(shape) - einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros)) - - key = jax.random.key(0) - x = jax.random.normal(key, (8, 64, 32)) # (BSD) - eqn = "BSD,3KDH->3BSKH" - - params = einsum.init(key, eqn, x) - output = einsum.apply(params, eqn, x) - - params_lora = einsum_lora.init(key, eqn, x) - output_lora = einsum_lora.apply(params_lora, eqn, x) - - # Results are the same since the LoRA parameters are initialized to zeros. 
- assert jnp.allclose(output, output_lora) - - -def test_lora_ffn_params_shape(): - ffn = lora.FeedForward(features=8, hidden_dim=32) - ffn_lora = lora.FeedForward( - features=8, - hidden_dim=32, - lora_config=lora.LoRAConfig(rank=2), - ) - - key = jax.random.key(0) - x = jax.random.normal(key, (2, 8)) - - params = ffn.init(key, x) - assert params["params"]["gating_einsum"].shape == (2, 8, 32) - assert params["params"]["linear"].shape == (32, 8) - - params_lora = ffn_lora.init(key, x) - assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32) - assert params_lora["params"]["linear"].shape == (32, 8) - assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2) - assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32) - assert params_lora["params"]["linear_lora_a"].shape == (32, 2) - assert params_lora["params"]["linear_lora_b"].shape == (2, 8) - - -def test_lora_ffn_same_output(): - ffn = lora.FeedForward(features=8, hidden_dim=32) - ffn_lora = lora.FeedForward( - features=8, - hidden_dim=32, - lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros), - ) - - key = jax.random.key(0) - x = jax.random.normal(key, (2, 8)) - - params = ffn.init(key, x) - output = ffn.apply(params, x) - - params_lora = ffn_lora.init(key, x) - output_lora = ffn_lora.apply(params_lora, x) - - assert jnp.allclose(output, output_lora) diff --git a/capvector-pi05/src/openpi/models/model.py b/capvector-pi05/src/openpi/models/model.py deleted file mode 100644 index f097be4e3558084b017c7bceabc7d78a159d7da6..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/model.py +++ /dev/null @@ -1,335 +0,0 @@ -import abc -from collections.abc import Sequence -import dataclasses -import enum -import logging -import pathlib -from typing import Generic, TypeVar - -import augmax -from flax import nnx -from flax import struct -from flax import traverse_util -import jax -import jax.numpy as jnp -import numpy as np -import orbax.checkpoint as ocp -import safetensors -import torch - -from openpi.models_pytorch import pi0_pytorch -from openpi.shared import image_tools -import openpi.shared.array_typing as at - -logger = logging.getLogger("openpi") - -# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays) -ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray) - - -class ModelType(enum.Enum): - """Supported model types.""" - - PI0 = "pi0" - PI0_FAST = "pi0_fast" - PI05 = "pi05" - - -# The model always expects these images -IMAGE_KEYS = ( - "base_0_rgb", - "left_wrist_0_rgb", - "right_wrist_0_rgb", -) - - -# This may need change if we release a small model. -IMAGE_RESOLUTION = (224, 224) - - -# Data format -# -# Data transforms produce the model input as a nested dictionary which is later converted -# into `Obesrvation` and `Actions` objects. See below. -# -# In the dictory form, this data should look like: -# { -# # Observation data. -# "image": { -# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255] -# ... # Additional camera views -# }, -# "image_mask": { -# "base_0_rgb": bool[*b], # True if image is valid -# ... 
# Masks for additional views -# }, -# "state": float32[*b, s], # Low-dimensional robot state -# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt -# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt -# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model -# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model -# -# # Actions data. -# "actions": float32[*b ah ad] -# } -# where: -# *b = batch dimensions -# h,w = image height/width -# s = state dimension -# l = sequence length -# -@at.typecheck -@struct.dataclass -class Observation(Generic[ArrayT]): - """Holds observations, i.e., inputs to the model. - - See `Observation.from_dict` to see the expected dictionary form. This is the format - that should be produced by the data transforms. - """ - - # Images, in [-1, 1] float32. - images: dict[str, at.Float[ArrayT, "*b h w c"]] - # the padding area for non-rectangular input images is False - image_padding_mask: dict[str, at.Bool[ArrayT, "*b w c"]] - # Image masks, with same keys as images. - image_masks: dict[str, at.Bool[ArrayT, "*b"]] - # Low-dimensional robot state. - state: at.Float[ArrayT, "*b s"] - - # Tokenized prompt. - tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None - # Tokenized prompt mask. - tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None - - # pi0-fast model specific fields. - - # Token auto-regressive mask (for FAST autoregressive model). - token_ar_mask: at.Int[ArrayT, "*b l"] | None = None - # Token loss mask (for FAST autoregressive model). - token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None - - @classmethod - def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]": - """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.""" - # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together. - if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data): - raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.") - # If images are uint8, convert them to [-1, 1] float32. - for key in data["image"]: - if data["image"][key].dtype == np.uint8: - data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0 - elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8: - data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0 - return cls( - images=data["image"], - image_padding_mask=data.get("image_padding_mask", {}), - image_masks=data["image_mask"], - state=data["state"], - tokenized_prompt=data.get("tokenized_prompt"), - tokenized_prompt_mask=data.get("tokenized_prompt_mask"), - token_ar_mask=data.get("token_ar_mask"), - token_loss_mask=data.get("token_loss_mask"), - ) - - def to_dict(self) -> at.PyTree[ArrayT]: - """Convert the Observation to a nested dict.""" - result = dataclasses.asdict(self) - result["image"] = result.pop("images") - result["image_mask"] = result.pop("image_masks") - return result - - -# Defines the format of the actions. This field is included as "actions" inside the dictionary -# produced by the data transforms. 
-Actions = at.Float[ArrayT, "*b ah ad"] - - -def preprocess_observation( - rng: at.KeyArrayLike | None, - observation: Observation, - *, - train: bool = False, - image_keys: Sequence[str] = IMAGE_KEYS, - image_resolution: tuple[int, int] = IMAGE_RESOLUTION, -) -> Observation: - """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and - filling in a default image mask (if necessary). - """ - - if not set(image_keys).issubset(observation.images): - raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") - - batch_shape = observation.state.shape[:-1] - - out_images = {} - for key in image_keys: - image = observation.images[key] - if image.shape[1:3] != image_resolution: - logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") - image = image_tools.resize_with_pad(image, *image_resolution) - - if train: - # Convert from [-1, 1] to [0, 1] for augmax. - image = image / 2.0 + 0.5 - - transforms = [] - if "wrist" not in key: - height, width = image.shape[1:3] - transforms += [ - augmax.RandomCrop(int(width * 0.95), int(height * 0.95)), - augmax.Resize(width, height), - augmax.Rotate((-5, 5)), - ] - transforms += [ - augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), - ] - sub_rngs = jax.random.split(rng, image.shape[0]) - image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image) - - # Back to [-1, 1]. - image = image * 2.0 - 1.0 - - out_images[key] = image - - # obtain mask - out_masks = {} - for key in out_images: - if key not in observation.image_masks: - # do not mask by default - out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool) - else: - out_masks[key] = jnp.asarray(observation.image_masks[key]) - - return Observation( - images=out_images, - image_masks=out_masks, - state=observation.state, - tokenized_prompt=observation.tokenized_prompt, - tokenized_prompt_mask=observation.tokenized_prompt_mask, - token_ar_mask=observation.token_ar_mask, - token_loss_mask=observation.token_loss_mask, - ) - - -@dataclasses.dataclass(frozen=True) -class BaseModelConfig(abc.ABC): - """Configuration shared by all models. Specific models should inherit from this class, and implement the `create` - method to create the corresponding model. - """ - - # Action space dimension. - action_dim: int - # Action sequence length. - action_horizon: int - # Tokenized prompt maximum length. 
- max_token_len: int - - @property - @abc.abstractmethod - def model_type(self) -> ModelType: - """The model type.""" - - @abc.abstractmethod - def create(self, rng: at.KeyArrayLike) -> "BaseModel": - """Create a new model, initializing parameters.""" - - def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel": - """Create a model with the given parameters.""" - model = nnx.eval_shape(self.create, jax.random.key(0)) - graphdef, state = nnx.split(model) - if remove_extra_params: - params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params) - at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False) - state.replace_by_pure_dict(params) - return nnx.merge(graphdef, state) - - def load_pytorch(self, train_config, weight_path: str): - logger.info(f"train_config: {train_config}") - model = pi0_pytorch.PI0Pytorch(config=train_config.model) - safetensors.torch.load_model(model, weight_path) - return model - - @abc.abstractmethod - def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]: - """Returns the input specification for the model. Values are jax.ShapeDtypeStruct.""" - - def fake_obs(self, batch_size: int = 1) -> Observation: - observation_spec, _ = self.inputs_spec(batch_size=batch_size) - return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec) - - def fake_act(self, batch_size: int = 1) -> Actions: - _, action_spec = self.inputs_spec(batch_size=batch_size) - return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec) - - -@dataclasses.dataclass -class BaseModel(nnx.Module, abc.ABC): - """Base class for all model implementations. Specific models should inherit from this class. They should call - super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len). - """ - - action_dim: int - action_horizon: int - max_token_len: int - - @abc.abstractmethod - def compute_loss( - self, - rng: at.KeyArrayLike, - observation: Observation, - actions: Actions, - *, - train: bool = False, - ) -> at.Float[at.Array, "*b ah"]: ... - - @abc.abstractmethod - def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ... - - -def restore_params( - params_path: pathlib.Path | str, - *, - restore_type: type[np.ndarray] | type[jax.Array] = jax.Array, - dtype: jnp.dtype | None = None, - sharding: jax.sharding.Sharding | None = None, -) -> at.Params: - """Restores unstructured params PyTree from a checkpoint. - - This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as - well as pre-trained checkpoints released for openpi. - - Args: - params_path: The local path to the checkpoint directory. - restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array. - dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint. - sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices. - - Returns: - The restored params. 
- """ - params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path - - if restore_type is jax.Array and sharding is None: - mesh = jax.sharding.Mesh(jax.devices(), ("x",)) - sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - - with ocp.PyTreeCheckpointer() as ckptr: - metadata = ckptr.metadata(params_path) - item = {"params": metadata["params"]} - - params = ckptr.restore( - params_path, - ocp.args.PyTreeRestore( - item=item, - restore_args=jax.tree.map( - lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item - ), - ), - )["params"] - - # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is - # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict". - flat_params = traverse_util.flatten_dict(params) - if all(kp[-1] == "value" for kp in flat_params): - flat_params = {kp[:-1]: v for kp, v in flat_params.items()} - return traverse_util.unflatten_dict(flat_params) diff --git a/capvector-pi05/src/openpi/models/model_test.py b/capvector-pi05/src/openpi/models/model_test.py deleted file mode 100644 index 528dc32f9c0f42d5008f6496deec056141f8cb03..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/model_test.py +++ /dev/null @@ -1,94 +0,0 @@ -from flax import nnx -import jax -import pytest - -from openpi.models import model as _model -from openpi.models import pi0_config -from openpi.models import pi0_fast -from openpi.shared import download -from openpi.shared import nnx_utils - - -def test_pi0_model(): - key = jax.random.key(0) - config = pi0_config.Pi0Config() - model = config.create(key) - - batch_size = 2 - obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) - - loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) - assert loss.shape == (batch_size, config.action_horizon) - - actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) - assert actions.shape == (batch_size, model.action_horizon, model.action_dim) - - -def test_pi0_lora_model(): - key = jax.random.key(0) - config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora") - model = config.create(key) - - batch_size = 2 - obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) - - loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) - assert loss.shape == (batch_size, config.action_horizon) - - actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) - assert actions.shape == (batch_size, model.action_horizon, model.action_dim) - - -def test_pi0_fast_model(): - key = jax.random.key(0) - config = pi0_fast.Pi0FASTConfig() - model = config.create(key) - - batch_size = 2 - obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) - - loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) - assert loss.shape == (batch_size,) - - actions = nnx_utils.module_jit(model.sample_actions)(key, obs) - assert actions.shape == (batch_size, 256) - - -def test_pi0_fast_lora_model(): - key = jax.random.key(0) - config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora") - model = config.create(key) - - batch_size = 2 - obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) - - loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) - assert loss.shape == (batch_size,) - - actions = nnx_utils.module_jit(model.sample_actions)(key, obs) - assert actions.shape == 
(batch_size, 256) - - lora_filter = nnx_utils.PathRegex(".*lora.*") - model_state = nnx.state(model) - - lora_state_elems = list(model_state.filter(lora_filter)) - assert len(lora_state_elems) > 0 - - -@pytest.mark.manual -def test_model_restore(): - key = jax.random.key(0) - config = pi0_config.Pi0Config() - - batch_size = 2 - obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) - - model = config.load( - _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params")) - ) - - loss = model.compute_loss(key, obs, act) - assert loss.shape == (batch_size, config.action_horizon) - - actions = model.sample_actions(key, obs, num_steps=10) - assert actions.shape == (batch_size, model.action_horizon, model.action_dim) diff --git a/capvector-pi05/src/openpi/models/pi0.py b/capvector-pi05/src/openpi/models/pi0.py deleted file mode 100644 index 90fd7935a3f46c86b3100a42db701e8af719cbff..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/pi0.py +++ /dev/null @@ -1,279 +0,0 @@ -import logging - -import einops -import flax.nnx as nnx -import flax.nnx.bridge as nnx_bridge -import jax -import jax.numpy as jnp -from typing_extensions import override - -from openpi.models import model as _model -from openpi.models import pi0_config -import openpi.models.gemma as _gemma -import openpi.models.siglip as _siglip -from openpi.shared import array_typing as at - -logger = logging.getLogger("openpi") - - -def make_attn_mask(input_mask, mask_ar): - """Adapted from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on - it and false where it shares the same attention mask as the previous token. 
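    Example (added for illustration, not part of the original big_vision docstring): for a
    4-token sequence where the last token is padding and the third token starts a causal
    block,

        input_mask = [[1 1 1 0]]
        mask_ar    = [[0 0 1 0]]

    the returned mask (queries along rows, keys along columns) is

        [[1 1 0 0]
         [1 1 0 0]
         [1 1 1 0]
         [0 0 0 0]]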
- """ - mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape) - cumsum = jnp.cumsum(mask_ar, axis=1) - attn_mask = cumsum[:, None, :] <= cumsum[:, :, None] - valid_mask = input_mask[:, None, :] * input_mask[:, :, None] - return jnp.logical_and(attn_mask, valid_mask) - - -@at.typecheck -def posemb_sincos( - pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float -) -> at.Float[at.Array, "b {embedding_dim}"]: - """Computes sine-cosine positional embedding vectors for scalar positions.""" - if embedding_dim % 2 != 0: - raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2") - - fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2) - period = min_period * (max_period / min_period) ** fraction - sinusoid_input = jnp.einsum( - "i,j->ij", - pos, - 1.0 / period * 2 * jnp.pi, - precision=jax.lax.Precision.HIGHEST, - ) - return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1) - - -class Pi0(_model.BaseModel): - def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs): - super().__init__(config.action_dim, config.action_horizon, config.max_token_len) - self.pi05 = config.pi05 - paligemma_config = _gemma.get_config(config.paligemma_variant) - action_expert_config = _gemma.get_config(config.action_expert_variant) - # TODO: rewrite gemma in NNX. For now, use bridge. - llm = nnx_bridge.ToNNX( - _gemma.Module( - configs=[paligemma_config, action_expert_config], - embed_dtype=config.dtype, - adarms=config.pi05, - ) - ) - llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False]) - img = nnx_bridge.ToNNX( - _siglip.Module( - num_classes=paligemma_config.width, - variant="So400m/14", - pool_type="none", - scan=True, - dtype_mm=config.dtype, - ) - ) - img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) - self.PaliGemma = nnx.Dict(llm=llm, img=img) - self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs) - if config.pi05: - self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) - self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) - else: - self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs) - self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs) - self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) - self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs) - - # This attribute gets automatically set by model.train() and model.eval(). 
- self.deterministic = True - - @at.typecheck - def embed_prefix( - self, obs: _model.Observation - ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]: - input_mask = [] - ar_mask = [] - tokens = [] - # embed images - for name in obs.images: - image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False) - - tokens.append(image_tokens) - input_mask.append( - einops.repeat( - obs.image_masks[name], - "b -> b s", - s=image_tokens.shape[1], - ) - ) - # image tokens attend to each other - ar_mask += [False] * image_tokens.shape[1] - - # add language (aka tokenized inputs) - if obs.tokenized_prompt is not None: - tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed") - tokens.append(tokenized_inputs) - input_mask.append(obs.tokenized_prompt_mask) - # full attention between image and language inputs - ar_mask += [False] * tokenized_inputs.shape[1] - tokens = jnp.concatenate(tokens, axis=1) - input_mask = jnp.concatenate(input_mask, axis=1) - ar_mask = jnp.array(ar_mask) - return tokens, input_mask, ar_mask - - @at.typecheck - def embed_suffix( - self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"] - ) -> tuple[ - at.Float[at.Array, "b s emb"], - at.Bool[at.Array, "b s"], - at.Bool[at.Array, " s"], - at.Float[at.Array, "b emb"] | None, - ]: - input_mask = [] - ar_mask = [] - tokens = [] - if not self.pi05: - # add a single state token - state_token = self.state_proj(obs.state)[:, None, :] - tokens.append(state_token) - input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_)) - # image/language inputs do not attend to state or actions - ar_mask += [True] - - action_tokens = self.action_in_proj(noisy_actions) - # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] - time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0) - if self.pi05: - # time MLP (for adaRMS) - time_emb = self.time_mlp_in(time_emb) - time_emb = nnx.swish(time_emb) - time_emb = self.time_mlp_out(time_emb) - time_emb = nnx.swish(time_emb) - action_expert_tokens = action_tokens - adarms_cond = time_emb - else: - # mix timestep + action information using an MLP (no adaRMS) - time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon) - action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1) - action_time_tokens = self.action_time_mlp_in(action_time_tokens) - action_time_tokens = nnx.swish(action_time_tokens) - action_time_tokens = self.action_time_mlp_out(action_time_tokens) - action_expert_tokens = action_time_tokens - adarms_cond = None - tokens.append(action_expert_tokens) - input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_)) - # image/language/state inputs do not attend to action tokens - ar_mask += [True] + ([False] * (self.action_horizon - 1)) - tokens = jnp.concatenate(tokens, axis=1) - input_mask = jnp.concatenate(input_mask, axis=1) - ar_mask = jnp.array(ar_mask) - return tokens, input_mask, ar_mask, adarms_cond - - @override - def compute_loss( - self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False - ) -> at.Float[at.Array, "*b ah"]: - preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) - observation = _model.preprocess_observation(preprocess_rng, observation, train=train) - - batch_shape = actions.shape[:-2] - noise = jax.random.normal(noise_rng, actions.shape) - time = 
jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 - time_expanded = time[..., None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - - # one big forward pass of prefix + suffix at once - prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) - suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time) - input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1) - ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0) - attn_mask = make_attn_mask(input_mask, ar_mask) - positions = jnp.cumsum(input_mask, axis=1) - 1 - (prefix_out, suffix_out), _ = self.PaliGemma.llm( - [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond] - ) - v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) - - return jnp.mean(jnp.square(v_t - u_t), axis=-1) - - @override - def sample_actions( - self, - rng: at.KeyArrayLike, - observation: _model.Observation, - *, - num_steps: int | at.Int[at.Array, ""] = 10, - noise: at.Float[at.Array, "b ah ad"] | None = None, - ) -> _model.Actions: - observation = _model.preprocess_observation(None, observation, train=False) - # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target - # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry. - dt = -1.0 / num_steps - batch_size = observation.state.shape[0] - if noise is None: - noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim)) - - # first fill KV cache with a forward pass of the prefix - prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) - prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) - positions = jnp.cumsum(prefix_mask, axis=1) - 1 - _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions) - - def step(carry): - x_t, time = carry - suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix( - observation, x_t, jnp.broadcast_to(time, batch_size) - ) - # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each - # other - suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask) - # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the - # prefix tokens - prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1]) - # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which - # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values) - full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1) - assert full_attn_mask.shape == ( - batch_size, - suffix_tokens.shape[1], - prefix_tokens.shape[1] + suffix_tokens.shape[1], - ) - # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens - positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1 - - (prefix_out, suffix_out), _ = self.PaliGemma.llm( - [None, suffix_tokens], - mask=full_attn_mask, - positions=positions, - kv_cache=kv_cache, - adarms_cond=[None, adarms_cond], - ) - assert prefix_out is None - v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) - - return x_t + dt * v_t, time + dt - - def cond(carry): - x_t, time = carry - # robust to floating-point 
error - return time >= -dt / 2 - - x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0)) - return x_0 diff --git a/capvector-pi05/src/openpi/models/pi0_config.py b/capvector-pi05/src/openpi/models/pi0_config.py deleted file mode 100644 index 26c97f720d491c917fdb870ff8a85102e3c7a3b5..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/pi0_config.py +++ /dev/null @@ -1,108 +0,0 @@ -import dataclasses -from typing import TYPE_CHECKING - -import flax.nnx as nnx -import jax -import jax.numpy as jnp -from typing_extensions import override - -from openpi.models import model as _model -import openpi.models.gemma as _gemma -from openpi.shared import array_typing as at -import openpi.shared.nnx_utils as nnx_utils - -if TYPE_CHECKING: - from openpi.models.pi0 import Pi0 - - -@dataclasses.dataclass(frozen=True) -class Pi0Config(_model.BaseModelConfig): - dtype: str = "bfloat16" - paligemma_variant: _gemma.Variant = "gemma_2b" - action_expert_variant: _gemma.Variant = "gemma_300m" - - # Set the model specific defaults. - action_dim: int = 32 - action_horizon: int = 50 - max_token_len: int = None # type: ignore - # Pi05 has two differences from Pi0: - # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix - # - the action expert uses adaRMSNorm to inject the flow matching timestep - pi05: bool = False - # This config option is not used directly by the model, but it is read by the ModelTransformFactory. - discrete_state_input: bool = None # type: ignore - - def __post_init__(self): - if self.max_token_len is None: - object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48) - if self.discrete_state_input is None: - object.__setattr__(self, "discrete_state_input", self.pi05) - - @property - @override - def model_type(self) -> _model.ModelType: - if self.pi05: - return _model.ModelType.PI05 - return _model.ModelType.PI0 - - @override - def create(self, rng: at.KeyArrayLike) -> "Pi0": - from openpi.models.pi0 import Pi0 - - return Pi0(self, rngs=nnx.Rngs(rng)) - - @override - def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]: - image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32) - image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_) - - with at.disable_typechecking(): - observation_spec = _model.Observation( - images={ - "base_0_rgb": image_spec, - "left_wrist_0_rgb": image_spec, - "right_wrist_0_rgb": image_spec, - }, - image_masks={ - "base_0_rgb": image_mask_spec, - "left_wrist_0_rgb": image_mask_spec, - "right_wrist_0_rgb": image_mask_spec, - }, - state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), - tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), - tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool), - ) - action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32) - - return observation_spec, action_spec - - def get_freeze_filter(self) -> nnx.filterlib.Filter: - """Returns the freeze filter based on the model config.""" - filters = [] - has_lora = False - gemma_params_filter = nnx_utils.PathRegex(".*llm.*") - action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*") - if "lora" in self.paligemma_variant: - filters.append( - gemma_params_filter, - ) - if "lora" not in self.action_expert_variant: - # If only freeze gemma params, exclude action expert params. 
- filters.append( - nnx.Not(action_expert_params_filter), - ) - has_lora = True - elif "lora" in self.action_expert_variant: - filters.append( - action_expert_params_filter, - ) - has_lora = True - - if has_lora: - # If any lora is used, exclude all lora params. - filters.append( - nnx.Not(nnx_utils.PathRegex(".*lora.*")), - ) - if not filters: - return nnx.Nothing - return nnx.All(*filters) diff --git a/capvector-pi05/src/openpi/models/pi0_fast.py b/capvector-pi05/src/openpi/models/pi0_fast.py deleted file mode 100644 index 8c3ed5503c27facc30a293d4740195ebc8b5fe46..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/pi0_fast.py +++ /dev/null @@ -1,313 +0,0 @@ -import dataclasses -import logging -from typing import Any - -import einops -import flax.nnx as nnx -import flax.nnx.bridge as nnx_bridge -import jax -import jax.numpy as jnp -from typing_extensions import override - -from openpi.models import model as _model -import openpi.models.gemma_fast as _gemma -import openpi.models.siglip as _siglip -from openpi.shared import array_typing as at -import openpi.shared.nnx_utils as nnx_utils - -logger = logging.getLogger("openpi") - -PALIGEMMA_EOS_TOKEN = 1 - - -def make_attn_mask(input_mask, mask_ar): - """Adapted from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on - it and false where it shares the same attention mask as the previous token. - """ - mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape) - cumsum = jnp.cumsum(mask_ar, axis=1) - attn_mask = cumsum[:, None, :] <= cumsum[:, :, None] - valid_mask = input_mask[:, None, :] * input_mask[:, :, None] - return jnp.logical_and(attn_mask, valid_mask) - - -@jax.vmap -def left_to_right_align(x, input_mask, attn_mask): - """Converts input from left-align to right-aligned.""" - # Due to vmap, this is operating in a single example (not batch level). 
- assert x.ndim == 2 - assert input_mask.ndim == 1 - assert attn_mask.ndim == 2 - assert x.shape[0] == input_mask.shape[0] - assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape - seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1 - x = jnp.roll(x, -seqlen, axis=0) - input_mask = jnp.roll(input_mask, -seqlen, axis=0) - attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1)) - return x, input_mask, attn_mask - - -def put_along_last_axis(arr, indices, values): - """Like np.put_along_axis(..., axis=-1), since jax is missing it.""" - assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim) - onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype) - put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot) - put_values = jnp.einsum("...i,...in->...n", values, onehot) - return jnp.where(put_mask, put_values, arr) - - -@dataclasses.dataclass(frozen=True) -class Pi0FASTConfig(_model.BaseModelConfig): - dtype: str = "bfloat16" - paligemma_variant: _gemma.Variant = "gemma_2b" - - # Set the model specific defaults. - action_dim: int = 32 - action_horizon: int = 32 - max_token_len: int = 250 - - # Tokenizer for the fast model. - fast_model_tokenizer: Any | None = None - # Keyword arguments for the fast model tokenizer. - fast_model_tokenizer_kwargs: dict[str, Any] | None = None - - @property - @override - def model_type(self) -> _model.ModelType: - return _model.ModelType.PI0_FAST - - @override - def create(self, rng: at.KeyArrayLike) -> "Pi0FAST": - return Pi0FAST(self, rngs=nnx.Rngs(rng)) - - @override - def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]: - image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32) - image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_) - - with at.disable_typechecking(): - observation_spec = _model.Observation( - images={ - "base_0_rgb": image_spec, - "base_1_rgb": image_spec, - "wrist_0_rgb": image_spec, - }, - image_masks={ - "base_0_rgb": image_mask_spec, - "base_1_rgb": image_mask_spec, - "wrist_0_rgb": image_mask_spec, - }, - state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), - tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), - tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool), - token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), - token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_), - ) - action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32) - - return observation_spec, action_spec - - def get_freeze_filter(self) -> nnx.filterlib.Filter: - """Returns the freeze filter based on the model config.""" - if "lora" in self.paligemma_variant: - return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*"))) - return nnx.Nothing - - -class Pi0FAST(_model.BaseModel): - def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs): - super().__init__(config.action_dim, config.action_horizon, config.max_token_len) - paligemma_config = _gemma.get_config(config.paligemma_variant) - # TODO: rewrite gemma in NNX. For now, use bridge. 
- llm = nnx_bridge.ToNNX( - _gemma.Module( - **paligemma_config, - embed_dtype=config.dtype, - cache_dtype=config.dtype, - ) - ) - llm.lazy_init(rngs=rngs, method="init") - img = nnx_bridge.ToNNX( - _siglip.Module( - num_classes=paligemma_config.width, - variant="So400m/14", - pool_type="none", - scan=True, - dtype_mm=config.dtype, - ) - ) - img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) - self.PaliGemma = nnx.Dict(llm=llm, img=img) - - @at.typecheck - def embed_inputs( - self, obs: _model.Observation - ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]: - input_mask = [] - ar_mask = [] - token_embeddings = [] - # embed images - for name in obs.images: - image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False) - - token_embeddings.append(image_token_embeddings) - input_mask.append( - einops.repeat( - obs.image_masks[name], - "b -> b s", - s=image_token_embeddings.shape[1], - ) - ) - # image tokens attend to each other --> AR mask = 0 - ar_mask.append(0 * input_mask[-1]) - - # add tokenized inputs - assert obs.tokenized_prompt is not None, "Tokenized prompt is required" - assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required" - assert obs.token_ar_mask is not None, "Token auto-regressive mask is required" - tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True) - token_embeddings.append(tokenized_inputs_embeddings) - input_mask.append(obs.tokenized_prompt_mask) - ar_mask.append(obs.token_ar_mask) - - # return embeddings, input mask, and ar mask - return ( - jnp.concatenate(token_embeddings, axis=1), - jnp.concatenate(input_mask, axis=1), - jnp.concatenate(ar_mask, axis=1), - ) - - @override - def compute_loss( - self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False - ) -> at.Float[at.Array, "*b ah"]: - observation = _model.preprocess_observation( - rng, observation, train=train, image_keys=list(observation.images.keys()) - ) - - # Compute inputs: one big forward pass of prefix + suffix at once - input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation) - attn_mask = make_attn_mask(input_mask, ar_mask) - - # Compute one-hot targets: we predict *next* token, so shift the input tokens by one. - targets = jax.nn.one_hot( - observation.tokenized_prompt[:, 1:], - self.PaliGemma.llm.module.vocab_size, - ) - - # Each input predicts *next* token, so we don't input the last token. - pre_logits, _, _ = self.PaliGemma.llm( - embedded_prefix=input_token_embeddings[:, :-1], - mask=attn_mask[:, :-1, :-1], - return_prelogits=True, - ) - - # Only decode logits for the target tokens to save memory - # (decoding matmul is large because it is a seq_len x vocab_size dense layer). - logits, _ = self.PaliGemma.llm( - pre_logits=pre_logits[:, -targets.shape[1] :], - ) - logp = jax.nn.log_softmax(logits, axis=-1) - - # Compute CE loss on token targets - assert observation.token_loss_mask is not None, "Token loss mask is required" - loss_mask = observation.token_loss_mask[:, 1:] - token_pplx = jnp.sum(targets * logp, axis=-1) - return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1) - - @override - def sample_actions( - self, - rng: at.KeyArrayLike, - observation: _model.Observation, - *, - max_decoding_steps: int | at.Int[at.Array, ""] = 256, - temperature: float = 0.0, - ) -> _model.Actions: - # TODO: this is a hack to get the image keys. 
- observation = _model.preprocess_observation( - None, observation, train=False, image_keys=list(observation.images.keys()) - ) - - # embed inputs - prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation) - prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) - - # left to right align all input token sequences - prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align( - prefix_token_embeddings, prefix_mask, prefix_attn_mask - ) - prefill_size = prefix_token_embeddings.shape[1] - prefill_len = jnp.sum(prefix_mask, axis=-1) - prefix_start = prefill_size - prefill_len - - # first fill KV cache with a forward pass of the prefix - # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps) - prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps))) - prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1 - prefix_logits, kv_cache, _ = self.PaliGemma.llm( - embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True - ) - - # prepare decoding -- final logit decodes the first token - last_logit = prefix_logits[:, -1:] - output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps)) - - def step(carry): - rng, last_logit, output_tokens, cache, _, step = carry - - # Sample token from last logit - # Split RNG for this step - rng, rng_step = jax.random.split(rng) - token = jax.lax.cond( - temperature > 0.0, - lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1), - lambda _: jnp.argmax(last_logit, axis=-1), - operand=None, - ) - output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token) - - # Check for early stopping --> stop if all batch elements have EOS token - has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1) - all_eos = jnp.all(has_eos) - - # Decode one step - token_embedding = self.PaliGemma.llm(token, embed_only=True) - positions = prefill_len[:, None] + step + 1 - mask = jnp.logical_and( - jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None], - jnp.arange(prefill_size + max_decoding_steps)[None, None, :] - < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))), - ) - last_logit, kv_cache, _ = self.PaliGemma.llm( - embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache - ) - - return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1 - - def cond(carry): - _, _, _, _, all_eos, step = carry - return (~all_eos) & (step < max_decoding_steps) - - # Use lax.while_loop so we can jit the full decoding loop. 
- _, _, output_tokens, _, _, _ = jax.lax.while_loop( - cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0) - ) - return output_tokens diff --git a/capvector-pi05/src/openpi/models/pi0_test.py b/capvector-pi05/src/openpi/models/pi0_test.py deleted file mode 100644 index 793739d137c88665a66f9396513f7264654e8cb1..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/pi0_test.py +++ /dev/null @@ -1,46 +0,0 @@ -import flax.nnx as nnx -import jax - -import openpi.models.pi0_config as _pi0_config - - -def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State: - abstract_model = nnx.eval_shape(config.create, jax.random.key(0)) - - freeze_filter = config.get_freeze_filter() - return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state() - - -def test_pi0_full_finetune(): - config = _pi0_config.Pi0Config() - state = _get_frozen_state(config) - assert len(state) == 0 - - -def test_pi0_gemma_lora(): - config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora") - state = _get_frozen_state(config) - assert len(state) == 9 - assert all("lora" not in p for p in state) - assert all("llm" in p for p in state) - assert all("_1" not in p for p in state) - - -def test_pi0_action_expert_lora(): - config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora") - state = _get_frozen_state(config) - # excluding embedder, rest of the params should be same as gemma_lora. - assert len(state) == 8 - assert all("lora" not in p for p in state) - assert all("llm" in p for p in state) - # all frozen params should have _1 in their path since it's the action expert. - assert all(any("_1" in p for p in path) for path in state) - - -def test_pi0_all_lora(): - config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora") - state = _get_frozen_state(config) - # sum of gemma_lora and action_expert_lora's frozen params. - assert len(state) == 17 - assert all("lora" not in p for p in state) - assert all("llm" in p for p in state) diff --git a/capvector-pi05/src/openpi/models/siglip.py b/capvector-pi05/src/openpi/models/siglip.py deleted file mode 100644 index e306802ac98a31174f7922a98018e13a1af58647..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/siglip.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2024 Big Vision Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
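The LoRA tests above check how many parameters `get_freeze_filter` marks as frozen. As a minimal sketch of how such a filter can be consumed (illustrative only; it mirrors the `_get_frozen_state` helper above and is not the repository's training code), one can partition an abstract model's parameters into frozen and trainable groups:

```python
import flax.nnx as nnx
import jax

from openpi.models import pi0_config

# Build the model abstractly so no real parameter memory is allocated.
config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
abstract_model = nnx.eval_shape(config.create, jax.random.key(0))

# Split parameters into frozen (matched by the freeze filter) and trainable groups.
freeze_filter = config.get_freeze_filter()
frozen = nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter))
trainable = nnx.state(abstract_model, nnx.All(nnx.Param, nnx.Not(freeze_filter)))
print(len(frozen.flat_state()), len(trainable.flat_state()))
```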
- -"""A refactored and simplified ViT adoptation for Pi, taken from big_vision.""" - -from collections.abc import Sequence - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np - -import openpi.training.sharding as sharding - - -def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32): - """Follows the MoCo v3 logic.""" - y, x = jnp.mgrid[:h, :w] - - assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" - omega = jnp.arange(width // 4) / (width // 4 - 1) - omega = 1.0 / (temperature**omega) - y = jnp.einsum("m,d->md", y.flatten(), omega) - x = jnp.einsum("m,d->md", x.flatten(), omega) - pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) - return jnp.asarray(pe, dtype)[None, :, :] - - -def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32): - if typ == "learn": - return self.param( - name, - nn.initializers.normal(stddev=1 / np.sqrt(width)), - (1, np.prod(seqshape), width), - dtype, - ) - if typ == "sincos2d": - return posemb_sincos_2d(*seqshape, width, dtype=dtype) - raise ValueError(f"Unknown posemb type: {typ}") - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block.""" - - mlp_dim: int | None = None # Defaults to 4x input dim - dropout: float = 0.0 - dtype_mm: str = "float32" - - @nn.compact - def __call__(self, x, deterministic=True): # noqa: FBT002 - """Applies Transformer MlpBlock module.""" - inits = { - "kernel_init": nn.initializers.xavier_uniform(), - "bias_init": nn.initializers.normal(stddev=1e-6), - } - - _, _, d = x.shape # n,l,d - x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x) - x = nn.gelu(x) - x = nn.Dropout(rate=self.dropout)(x, deterministic) - return nn.Dense(d, dtype=self.dtype_mm, **inits)(x) - - -class Encoder1DBlock(nn.Module): - """Single transformer encoder block (MHSA + MLP).""" - - mlp_dim: int | None = None # Defaults to 4x input dim - num_heads: int = 12 - dropout: float = 0.0 - dtype_mm: str = "float32" - - @nn.compact - def __call__(self, x, deterministic=True): # noqa: FBT002 - out = {} - x = sharding.activation_sharding_constraint(x) - y = nn.LayerNorm(dtype=self.dtype_mm)(x) - y = out["sa"] = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=deterministic, - dtype=self.dtype_mm, - )(y, y) - y = sharding.activation_sharding_constraint(y) - y = nn.Dropout(rate=self.dropout)(y, deterministic) - x = out["+sa"] = x + y - - y = nn.LayerNorm(dtype=self.dtype_mm)(x) - y = out["mlp"] = MlpBlock( - mlp_dim=self.mlp_dim, - dropout=self.dropout, - dtype_mm=self.dtype_mm, - )(y, deterministic) - y = sharding.activation_sharding_constraint(y) - y = nn.Dropout(rate=self.dropout)(y, deterministic) - x = out["+mlp"] = x + y - x = sharding.activation_sharding_constraint(x) - return x, out - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - - depth: int - mlp_dim: int | None = None # Defaults to 4x input dim - num_heads: int = 12 - dropout: float = 0.0 - scan: bool = False - remat_policy: str = "nothing_saveable" - dtype_mm: str = "float32" - - @nn.compact - def __call__(self, x, deterministic=True): # noqa: FBT002 - out = {} - - if self.scan: - block = nn.remat( - Encoder1DBlock, - prevent_cse=False, - static_argnums=(2,), # 0=self, 2=deterministic - policy=getattr(jax.checkpoint_policies, self.remat_policy, None), - ) - x, scan_out = nn.scan( - block, - variable_axes={"params": 0}, - split_rngs={"params": 
True, "dropout": True}, - in_axes=nn.broadcast, - length=self.depth, - )( - name="encoderblock", - dtype_mm=self.dtype_mm, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - dropout=self.dropout, - )(x, deterministic) - for lyr in range(self.depth): - out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out) - else: - # Input Encoder - for lyr in range(self.depth): - block_cur = Encoder1DBlock( - name=f"encoderblock_{lyr}", - dtype_mm=self.dtype_mm, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - dropout=self.dropout, - ) - x, out[f"block{lyr:02d}"] = block_cur(x, deterministic) - out["pre_ln"] = x # Alias for last block, but without the number in it. - - return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out - - -class MAPHead(nn.Module): - """Multihead Attention Pooling.""" - - mlp_dim: int | None = None # Defaults to 4x input dim - num_heads: int = 12 - dtype_mm: str = "float32" - - @nn.compact - def __call__(self, x): - n, _, d = x.shape # n,l,d - probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype) - probe = jnp.tile(probe, [n, 1, 1]) - - x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - dtype=self.dtype_mm, - kernel_init=nn.initializers.xavier_uniform(), - )(probe, x) - - y = nn.LayerNorm(dtype=self.dtype_mm)(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y) - return x[:, 0] - - -class _Module(nn.Module): - """ViT model.""" - - num_classes: int | None = None - patch_size: Sequence[int] = (16, 16) - width: int = 768 - depth: int = 12 - mlp_dim: int | None = None # Defaults to 4x input dim - num_heads: int = 12 - posemb: str = "learn" # Can also be "sincos2d" - rep_size: int | bool = False - dropout: float = 0.0 - pool_type: str = "gap" # Can also be "map" or "tok" - head_zeroinit: bool = True - scan: bool = False - # or "dots_with_no_batch_dims_saveable" for more speed (memory costly) - remat_policy: str = "nothing_saveable" - dtype_mm: str = "float32" - - @nn.compact - def __call__(self, image, *, train=False): - out = {} - - # Kevin edit: do patch extraction and posemb in float32, - # because I feel like it's a bit safer. - image = jnp.asarray(image, jnp.float32) - - # Patch extraction - x = out["stem"] = nn.Conv( - self.width, - self.patch_size, - strides=self.patch_size, - padding="VALID", - name="embedding", - dtype=jnp.float32, - )(image) - - n, h, w, c = x.shape - x = jnp.reshape(x, [n, h * w, c]) - - # Add posemb before adding extra token. 
- x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32) - - if self.pool_type == "tok": - cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype) - x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1) - - n, _, c = x.shape # n,l,d - x = nn.Dropout(rate=self.dropout)(x, not train) - - # Kevin edit: now cast back to dtype_mm (potentially half precision) - x = x.astype(self.dtype_mm) - - x, out["encoder"] = Encoder( - depth=self.depth, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - dropout=self.dropout, - scan=self.scan, - remat_policy=self.remat_policy, - dtype_mm=self.dtype_mm, - name="Transformer", - )(x, deterministic=not train) - encoded = out["encoded"] = x - - if self.pool_type == "map": - x = out["head_input"] = MAPHead( - num_heads=self.num_heads, - mlp_dim=self.mlp_dim, - dtype=self.dtype_mm, - )(x) - elif self.pool_type == "gap": - x = out["head_input"] = jnp.mean(x, axis=1) - elif self.pool_type == "0": - x = out["head_input"] = x[:, 0] - elif self.pool_type == "tok": - x = out["head_input"] = x[:, 0] - encoded = encoded[:, 1:] - elif self.pool_type == "none": - pass - else: - raise ValueError(f"Unknown pool type: '{self.pool_type}'") - - x_2d = jnp.reshape(encoded, [n, h, w, -1]) - - if self.rep_size: - rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits") - # NOTE: In the past we did not include tanh in pre_logits. - # For few-shot, it should not matter much, as it whitens anyways. - x_2d = nn.tanh(hid(x_2d)) - x = nn.tanh(hid(x)) - - out["pre_logits_2d"] = x_2d - out["pre_logits"] = x - - if self.num_classes: - kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw) - x_2d = out["logits_2d"] = head(x_2d) - x = out["logits"] = head(x) - - return x, out - - -def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802 - """Factory function, because linen really don't like what I'm doing!""" - return _Module(num_classes, **{**decode_variant(variant), **kw}) - - -def decode_variant(variant): - """Converts a string like "B" or "B/32" into a params dict.""" - if variant is None: - return {} - - v, patch = variant, {} - if "/" in variant: - v, patch = variant.split("/") - patch = {"patch_size": (int(patch), int(patch))} - - return { - # pylint:disable=line-too-long - # Reference: Table 2 of https://arxiv.org/abs/2106.04560. 
- "width": { - "mu": 32, - "Ti": 192, - "S": 384, - "M": 512, - "B": 768, - "L": 1024, - "So400m": 1152, - "H": 1280, - "g": 1408, - "g-opt": 1536, - "G": 1664, - "G-opt": 1536, - "e": 1792, - }[v], - "depth": { - "mu": 1, - "Ti": 12, - "S": 12, - "M": 12, - "B": 12, - "L": 24, - "So400m": 27, - "H": 32, - "g": 40, - "g-opt": 40, - "G": 48, - "G-opt": 48, - "e": 56, - }[v], - "mlp_dim": { - "mu": 128, - "Ti": 768, - "S": 1536, - "M": 2048, - "B": 3072, - "L": 4096, - "So400m": 4304, - "H": 5120, - "g": 6144, - "g-opt": 6144, - "G": 8192, - "G-opt": 8192, - "e": 15360, - }[v], - "num_heads": { - "mu": 2, - "Ti": 3, - "S": 6, - "M": 8, - "B": 12, - "L": 16, - "So400m": 16, - "H": 16, - "g": 16, - "g-opt": 16, - "G": 16, - "G-opt": 16, - "e": 16, - }[v], - # pylint:enable=line-too-long - **patch, - } diff --git a/capvector-pi05/src/openpi/models/tokenizer.py b/capvector-pi05/src/openpi/models/tokenizer.py deleted file mode 100644 index ec36ff6f83524e4eadec2c4474f92ee1b5f2bb20..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/tokenizer.py +++ /dev/null @@ -1,371 +0,0 @@ -import logging -import os - -import jax -import numpy as np -import orbax.checkpoint as ocp -import sentencepiece -from transformers import AutoProcessor - -import openpi.models.utils.fsq_tokenizer as fsq_tokenizer -import openpi.shared.download as download - - -class PaligemmaTokenizer: - def __init__(self, max_len: int = 48): - self._max_len = max_len - - path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) - with path.open("rb") as f: - self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) - - def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]: - cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ") - if state is not None: - # This is the Pi05 format, where the state is part of the discrete language input. - discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - state_str = " ".join(map(str, discretized_state)) - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " - tokens = self._tokenizer.encode(full_prompt, add_bos=True) - else: - # This is the Pi0 format, where the state is part of the continuous action expert input. - # tokenize "\n" separately as the "start of answer" token - tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n") - tokens_len = len(tokens) - if tokens_len < self._max_len: - padding = [False] * (self._max_len - tokens_len) - mask = [True] * tokens_len + padding - tokens = tokens + padding - else: - if len(tokens) > self._max_len: - logging.warning( - f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " - "Consider increasing the `max_token_len` in your model config if this happens frequently." 
- ) - tokens = tokens[: self._max_len] - mask = [True] * self._max_len - - return np.asarray(tokens), np.asarray(mask) - - -class FASTTokenizer: - def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"): - self._max_len = max_len - - # Download base PaliGemma tokenizer - path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) - with path.open("rb") as f: - self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) - - # Instantiate FAST tokenizer - self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) - self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens - - def tokenize( - self, prompt: str, state: np.ndarray, actions: np.ndarray | None - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - cleaned_text = prompt.lower().strip().replace("_", " ") - - # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) - discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - - # Convention: prefix includes prompt and string-representation of state, followed by ';' - state_str = " ".join(map(str, discretized_state)) - prefix = f"Task: {cleaned_text}, State: {state_str};\n" - prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) - - if actions is not None: - # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab - action_tokens = self._fast_tokenizer(actions[None])[0] - action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens) - - # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|' - postfix_tokens = ( - self._paligemma_tokenizer.encode("Action: ") - + action_tokens_in_pg.tolist() - + self._paligemma_tokenizer.encode("|", add_eos=True) - ) - else: - postfix_tokens = [] - - # Create output token sequence & masks - # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) - tokens = prefix_tokens + postfix_tokens - token_mask = [True] * len(tokens) - ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) - loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only - - # Pad tokens to max length - tokens_len = len(tokens) - if tokens_len < self._max_len: - padding = [False] * (self._max_len - tokens_len) - tokens = tokens + padding - token_mask = token_mask + padding - ar_mask = ar_mask + padding - loss_mask = loss_mask + padding - else: - if len(tokens) > self._max_len: - logging.warning( - f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " - "Consider increasing the `max_token_len` in your model config if this happens frequently." 
- ) - tokens = tokens[: self._max_len] - token_mask = token_mask[: self._max_len] - ar_mask = ar_mask[: self._max_len] - loss_mask = loss_mask[: self._max_len] - - return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) - - def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: - # Decode predicted output tokens - decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) - - # Extract actions from FAST model outputs - if "Action: " not in decoded_tokens: - return np.zeros((action_horizon, action_dim), dtype=np.float32) - - # Extract actions from decoded tokens - raw_action_tokens = np.array( - self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) - ) - action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) - return self._fast_tokenizer.decode( - [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim - )[0] - - def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: - if isinstance(tokens, list): - tokens = np.array(tokens) - return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens - - -########################################################################### -## The tokenizers below are used for RoboArena baseline implementations. ## -## They are *not* used for pi0-style models. ## -########################################################################### - - -class BinningTokenizer: - """ - Standard RT-2 / OpenVLA style binning tokenizer. - """ - - def __init__(self, max_len: int = 256, n_bins: int = 256): - self._max_len = max_len - self._n_bins = n_bins - - # Download base PaliGemma tokenizer - path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) - with path.open("rb") as f: - self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) - - self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens - - def tokenize( - self, prompt: str, state: np.ndarray, actions: np.ndarray | None - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Tokenize a prompt and state into a sequence of tokens. - - Args: - prompt: The text prompt to tokenize. - state: The state array to discretize and tokenize. - actions: Must be None. Action encoding is not currently supported. - - Returns: - A tuple of (tokens, token_mask, ar_mask, targets). - - Raises: - NotImplementedError: If actions is not None. 
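        Example (added for illustration; an inference-style call, so `actions` is None and
        the state vector is arbitrary):

            tokenizer = BinningTokenizer(max_len=64)
            tokens, token_mask, ar_mask, loss_mask = tokenizer.tokenize(
                "pick up the block", np.zeros(8, dtype=np.float32), None
            )
            # All four outputs are padded to shape (64,).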
- """ - cleaned_text = prompt.lower().strip().replace("_", " ") - - # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) - discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - - # Convention: prefix includes prompt and string-representation of state, followed by ';' - state_str = " ".join(map(str, discretized_state)) - prefix = f"Task: {cleaned_text}, State: {state_str};\n" - prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) - - if actions is not None: - raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)") - postfix_tokens = [] - - # Create output token sequence & masks - # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) - tokens = prefix_tokens + postfix_tokens - token_mask = [True] * len(tokens) - ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) - loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only - - # Pad tokens to max length - tokens_len = len(tokens) - if tokens_len < self._max_len: - padding = [False] * (self._max_len - tokens_len) - tokens = tokens + padding - token_mask = token_mask + padding - ar_mask = ar_mask + padding - loss_mask = loss_mask + padding - else: - if len(tokens) > self._max_len: - logging.warning( - f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " - "Consider increasing the `max_token_len` in your model config if this happens frequently." - ) - tokens = tokens[: self._max_len] - token_mask = token_mask[: self._max_len] - ar_mask = ar_mask[: self._max_len] - loss_mask = loss_mask[: self._max_len] - - return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) - - def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: - # Decode predicted output tokens - decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) - - # Extract actions from FAST model outputs - if "Action: " not in decoded_tokens: - return np.zeros((action_horizon, action_dim), dtype=np.float32) - - # Extract actions from decoded tokens - raw_action_tokens = np.array( - self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) - ) - action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) - if len(action_tokens) < action_horizon * action_dim: - return np.zeros([action_horizon, action_dim], dtype=np.float32) - action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim]) - return action_tokens / self._n_bins * 2 - 1 - - def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: - if isinstance(tokens, list): - tokens = np.array(tokens) - return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens - - -class FSQTokenizer: - """ - FSQ tokenizer from the FAST paper baselines. 
- """ - - def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None): - self._max_len = max_len - - assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided" - # Download tokenizer - path = download.maybe_download(fsq_tokenizer_path) - tok_path = os.path.join(path, os.listdir(path)[0]) - - # Split step from path - step = int(tok_path.split("/")[-1]) - base_path = tok_path.rsplit("/", 1)[0] - - mgr = ocp.CheckpointManager( - base_path, - item_handlers={ - "params": ocp.StandardCheckpointHandler(), - "opt_state": ocp.StandardCheckpointHandler(), - "config": ocp.JsonCheckpointHandler(), - }, - options=ocp.CheckpointManagerOptions(max_to_keep=1), - ) - - try: - restored = mgr.restore( - step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore()) - ) - config = restored["config"] - self._params = restored["params"] - self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config) - except Exception as e: - raise RuntimeError( - f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}" - ) from e - - # Compile tokenize and detokenize functions - self._tokenize_fn = jax.jit( - lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize) - ) - self._detokenize_fn = jax.jit( - lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize) - ) - - # Download base PaliGemma tokenizer - path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) - with path.open("rb") as f: - self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) - - self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens - - def tokenize( - self, prompt: str, state: np.ndarray, actions: np.ndarray | None - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - cleaned_text = prompt.lower().strip().replace("_", " ") - - # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) - discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - - # Convention: prefix includes prompt and string-representation of state, followed by ';' - state_str = " ".join(map(str, discretized_state)) - prefix = f"Task: {cleaned_text}, State: {state_str};\n" - prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) - - if actions is not None: - raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)") - postfix_tokens = [] - - # Create output token sequence & masks - # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) - tokens = prefix_tokens + postfix_tokens - token_mask = [True] * len(tokens) - ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) - loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only - - # Pad tokens to max length - tokens_len = len(tokens) - if tokens_len < self._max_len: - padding = [False] * (self._max_len - tokens_len) - tokens = tokens + padding - token_mask = token_mask + padding - ar_mask = ar_mask + padding - loss_mask = loss_mask + padding - else: - if len(tokens) > self._max_len: - logging.warning( - f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " - "Consider increasing the `max_token_len` in your model config if this happens frequently." 
- ) - tokens = tokens[: self._max_len] - token_mask = token_mask[: self._max_len] - ar_mask = ar_mask[: self._max_len] - loss_mask = loss_mask[: self._max_len] - - return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) - - def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: - # Decode predicted output tokens - decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) - - # Extract actions from FAST model outputs - if "Action: " not in decoded_tokens: - return np.zeros((action_horizon, action_dim), dtype=np.float32) - - # Extract actions from decoded tokens - raw_action_tokens = np.array( - self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) - ) - action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) - try: - # Move computation to CPU and compile on-demand - device = jax.devices("cpu")[0] - with jax.default_device(device): - detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0] - return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim]) - except Exception as e: - logging.warning(f"Error decoding FSQ: {e}") - return np.zeros((action_horizon, action_dim)) - - def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: - if isinstance(tokens, list): - tokens = np.array(tokens) - return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens diff --git a/capvector-pi05/src/openpi/models/tokenizer_test.py b/capvector-pi05/src/openpi/models/tokenizer_test.py deleted file mode 100644 index 3182e0a190ec94275f6af3ddb8bf896ee70be9c1..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/tokenizer_test.py +++ /dev/null @@ -1,27 +0,0 @@ -import numpy as np - -from openpi.models import tokenizer as _tokenizer - - -def test_tokenize(): - tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10) - tokens, masks = tokenizer.tokenize("Hello, world!") - - assert tokens.shape == (10,) - assert masks.shape == (10,) - - -def test_fast_tokenizer(): - prompt = "Hello, world!" 
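    # The FAST tokenizer packs the prompt, the discretized state, and (at training time)
    # the action tokens into a single fixed-length sequence, so every array asserted on
    # below has shape (max_len,).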
- state = np.random.rand(5).astype(np.float32) - action = np.random.rand(3, 2).astype(np.float32) - tokenizer = _tokenizer.FASTTokenizer(max_len=256) - tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action) - - assert tokens.shape == (256,) - assert token_masks.shape == (256,) - assert ar_masks.shape == (256,) - assert loss_masks.shape == (256,) - - act = tokenizer.extract_actions(tokens, 3, 2) - assert act.shape == (3, 2) diff --git a/capvector-pi05/src/openpi/models/utils/fsq_tokenizer.py b/capvector-pi05/src/openpi/models/utils/fsq_tokenizer.py deleted file mode 100644 index 3c8f4033d1a438d8652414871b2c46fd5a39af0a..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/utils/fsq_tokenizer.py +++ /dev/null @@ -1,472 +0,0 @@ -import math -from typing import Any, Literal - -import chex -from einops import einops -from flax import linen as nn -from flax.linen.module import Module -from flax.linen.module import compact -from flax.struct import dataclass -from flax.typing import Array -import jax -import jax.numpy as jnp - - -class FsqCodebook(nn.Module): - input_dim: int - target_codebook_size: int - codebook_type: Literal["fsq", "lfq"] - - _bins_per_dim: tuple[int] | None = None - - @property - def bins_per_dim(self) -> tuple[int]: - if self._bins_per_dim is not None: - return self._bins_per_dim - - if self.codebook_type == "fsq": - return self._get_bins_fsq(self.target_codebook_size) - elif self.codebook_type == "lfq": # noqa: RET505 - return self._get_bins_lfq(self.target_codebook_size) - elif self.codebook_type == "custom": - return self._get_bins_custom(self.target_codebook_size) - else: - raise ValueError(f"Codebook type {self.codebook_type} not supported.") - - @property - def place_values(self) -> jnp.ndarray: - place_values = [1] - for b in self.bins_per_dim[:-1]: - place_values.append(place_values[-1] * b) - return jnp.array(place_values) - - @staticmethod - def _get_bins_fsq(target_codebook_size: int) -> tuple[int]: - """ - Get bins per dimension based on codebook size, from the original FSQ paper. 
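        The product of the bins is the effective vocabulary size and only roughly matches
        the requested power of two, e.g. (8, 6, 5) -> 240 codes for a 2**8 target and
        (8, 5, 5, 5) -> 1000 codes for a 2**10 target.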
- """ - if target_codebook_size == 2**8: - return (8, 6, 5) - elif target_codebook_size == 2**10: # noqa: RET505 - return (8, 5, 5, 5) - elif target_codebook_size == 2**12: - return (7, 5, 5, 5, 5) - elif target_codebook_size == 2**14: - return (8, 8, 8, 6, 5) - elif target_codebook_size == 2**16: - return (8, 8, 8, 5, 5, 5) - else: - raise ValueError(f"Codebook size {target_codebook_size} not supported.") - - @staticmethod - def _get_bins_custom(target_codebook_size: int) -> tuple[int]: - if target_codebook_size == 2**8: - return (16, 16) - elif target_codebook_size == 2**10: # noqa: RET505 - return (32, 32) - elif target_codebook_size == 2**12: - return (64, 64) - elif target_codebook_size == 2**14: - return (128, 128) - elif target_codebook_size == 2**16: - return (256, 256) - return None - - @staticmethod - def _get_bins_lfq(target_codebook_size: int) -> tuple[int]: - """ - Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension) - """ - assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ" - - return (2,) * int(math.log2(target_codebook_size)) - - def setup(self): - self.proj_down = nn.Dense(len(self.bins_per_dim)) - self.proj_up = nn.Dense(self.input_dim) - - def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: - tokens, z = self.encode(inputs) - output = self.decode(tokens, z_grad=z) - return tokens, output - - def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: - bases = jnp.array(self.bins_per_dim) - - x = self.proj_down(inputs) - z = jnp.tanh(x) - - # Quantize - digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32) - tokens = self.undigitize(digits) - - return tokens, z - - def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray: - bases = jnp.array(self.bins_per_dim) - digits = self.digitize(tokens) - - z_q = digits / (bases - 1) * 2 - 1 - - if z_grad is not None: - chex.assert_equal_shape([z_q, z_grad]) - z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad - - return self.proj_up(z_q) - - def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray: - return jnp.sum(digits * jnp.array(self.place_values), axis=-1) - - def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray: - return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim) - - @property - def vocab_size(self) -> int: - return math.prod(self.bins_per_dim) - - -class ResNetDownBlock(nn.Module): - stride: int = 1 - n_filters: int = 64 - dropout_rate: float = 0.0 - group_size: int = 32 - - @nn.compact - def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: - skip = x - - if self.stride > 1 or x.shape[-1] != self.n_filters: - skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip) - - x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x) - x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x) - x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) - x = nn.relu(x) - x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x) - - return skip + x - - -class ResNetUpBlock(nn.Module): - stride: int = 1 - n_filters: int = 64 - dropout_rate: float = 0.0 - group_size: int = 32 - - @nn.compact - def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: - skip = x - - if self.stride > 1: - skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip) - - x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x) - x = 
nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x) - x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) - x = nn.relu(x) - x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x) - - return skip + x - - -@dataclass -class LfqCodebookOutput: - tokens: jnp.ndarray - z: jnp.ndarray - z_q: jnp.ndarray - token_log_probs: jnp.ndarray - commit_loss: jnp.ndarray - - -class LookupFreeQuantization(nn.Module): - num_dims: int - latent_dim: int - - def setup(self): - self.codebook = jnp.array([-1, 1]) - self.activation = nn.tanh - - self.project_down = nn.Dense(self.num_dims) - self.project_up = nn.Dense(self.latent_dim) - - def encode(self, z: jnp.ndarray) -> jnp.ndarray: - z = self.project_down(z) - token_squared_distances = jnp.square(z[..., None] - self.codebook) - token_bits = jnp.argmin(token_squared_distances, axis=-1) - return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1) - - def decode(self, tokens: jnp.ndarray) -> jnp.ndarray: - token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32) - return self.project_up(self.codebook[token_bits]) - - def loss(self, x: jnp.ndarray) -> LfqCodebookOutput: - z = self.project_down(x) - z = self.activation(z) - - token_squared_distances = jnp.square(z[..., None] - self.codebook) - tokens = jnp.argmin(token_squared_distances, axis=-1) - - token_bit_log_probs = -token_squared_distances - # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs - token_bit_expansions = jnp.bitwise_and( - jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None] - ).astype(jnp.int32) - token_log_probs = ( - token_bit_log_probs[..., 0] @ (1 - token_bit_expansions) - + token_bit_log_probs[..., 1] @ token_bit_expansions - ) # (batch_size, num_tokens, 2 ** num_dims) - token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1)) - chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims)) - - z_q = self.codebook[tokens] - commit_loss = jnp.square(z - z_q).mean() - z_q = jax.lax.stop_gradient(z_q - z) + z - - z_q = self.project_up(z_q) - z = self.project_up(z) - - tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1) - return LfqCodebookOutput( - tokens=tokens, - z=z, - z_q=z_q, - token_log_probs=jnp.zeros(()), - commit_loss=commit_loss, - ) - - -def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray: - return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q)) - - -class GeGLU(Module): - """Gated Linear Unit with GELU (GeGLU) activation function. - GeGLU is a Flax layer that combines a linear transformation with a GELU - activation function in a gating mechanism. It is often used in Transformer models - to provide non-linear capabilities while preserving a strong linear component. - - Attributes: - features: the number of output features (default: None). - """ - - output_dim: int = -1 - - @compact - def __call__(self, inputs: Array) -> Array: - """Applies the GeGLU activation to the inputs. - Args: - inputs: the nd-array to apply the GeGLU activation function to. - Returns: - The transformed input. 
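    In effect, for an input x with output feature size d, the layer computes (sketch of
    the code below):

        h = Dense(2 * d)(x)
        a, gate = h[..., :d], h[..., d:]
        out = a * gelu(gate)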
- """ - output_dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim - - x = nn.Dense(output_dim * 2)(inputs) - x, gate = x[..., :output_dim], x[..., output_dim:] - return x * nn.gelu(gate) - - -class CrossAttentionLayer(nn.Module): - dropout_rate: float = 0.0 - num_heads: int = None - causal: bool = False - mlp_ratio: float = 4.0 - - @nn.compact - def __call__( - self, - x: jnp.ndarray, - y: jnp.ndarray, - *, - mask_self: jnp.ndarray | None = None, - mask_cross: jnp.ndarray | None = None, - train: bool = True, - ) -> jnp.ndarray: - d_embed = x.shape[-1] - seq_len_q = x.shape[-2] - seq_len_k = y.shape[-2] - - if self.causal: - # One block size will be 1 - bs_q = max(seq_len_q // seq_len_k, 1) - bs_k = max(seq_len_k // seq_len_q, 1) - - mask_self = nn.make_causal_mask(x[..., 0]) - mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k) - - # Self-attention block - skip = x - x = nn.LayerNorm()(x) - x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads or d_embed // 64, - dropout_rate=self.dropout_rate, - deterministic=not train, - )(x, x, x, mask=mask_self) - x = skip + x - - # Cross-attention block - skip = x - x = nn.LayerNorm()(x) - x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads or d_embed // 64, - dropout_rate=self.dropout_rate, - deterministic=not train, - )(x, y, y, mask=mask_cross) - x = skip + x - - # MLP block - skip = x - x = nn.LayerNorm()(x) - x = nn.Dense(int(d_embed * self.mlp_ratio))(x) - x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) - x = GeGLU()(x) - x = nn.Dense(d_embed)(x) - return skip + x - - -def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray: - seq_len, d_embed = shape - - position = jnp.arange(0, seq_len, 1) - div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed)) - return jnp.concatenate( - [ - jnp.sin(position[:, jnp.newaxis] * div_term), - jnp.cos(position[:, jnp.newaxis] * div_term), - ], - axis=-1, - ) - - -class TokenizerEncoderDecoder(nn.Module): - num_tokens: int - num_cross_tokens: int - num_layers: int - causal: bool - - mlp_ratio: float = 4.0 - use_state_conditioning: bool = False - - @nn.compact - def __call__( - self, - y: jnp.ndarray, - *, - train: bool = True, - state_conditioning: jnp.ndarray | None = None, - mask: jnp.ndarray | None = None, - ) -> jnp.ndarray: - x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1])) - x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:]) - - if mask is not None: - # mask is (batch_dims..., num_cross_tokens) - chex.assert_equal_shape([y[..., 0], mask]) - attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens) - else: - attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens)) - - if self.use_state_conditioning: - assert state_conditioning is not None, "State conditioning is required for this model." 
- state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :] - y = jnp.concatenate([y, state_embed], axis=-2) - attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1) - - y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:]) - - for _ in range(self.num_layers): - x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)( - x, y, train=train, mask_self=None, mask_cross=attn_mask - ) - - return x - - -class FsqAttentionTokenizer(nn.Module): - embed_dim: int - data_dim: int - data_horizon: int - num_tokens: int - num_layers: int - target_codebook_size: int - causal: bool = False - mlp_ratio: float = 2.0 - - bound: float | None = None - - use_state_conditioning: bool = False - - @property - def vocab_size(self) -> int: - return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size)) # noqa: SLF001 - - def setup(self): - self.proj = nn.Dense(self.embed_dim) - self.encoder = TokenizerEncoderDecoder( - num_tokens=self.num_tokens, - num_cross_tokens=self.data_horizon, - num_layers=self.num_layers, - causal=self.causal, - use_state_conditioning=self.use_state_conditioning, - mlp_ratio=self.mlp_ratio, - ) - self.codebook = FsqCodebook( - input_dim=self.embed_dim, - target_codebook_size=self.target_codebook_size, - codebook_type="custom", - ) - self.decoder = TokenizerEncoderDecoder( - num_tokens=self.data_horizon, - num_cross_tokens=self.num_tokens, - num_layers=self.num_layers, - causal=self.causal, - use_state_conditioning=self.use_state_conditioning, - mlp_ratio=self.mlp_ratio, - ) - - self.proj_mean = nn.Dense(self.data_dim) - self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0)) - - def tokenize( - self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False - ) -> tuple[jnp.ndarray, jnp.ndarray]: - if self.bound is not None: - action = jnp.clip(action, -self.bound, self.bound) - - x = self.proj(action) - x = self.encoder(x, train=train, state_conditioning=obs) - - return self.codebook.encode(x) - - def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray: - x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs) - mean = self.proj_mean(x) - return mean * self.out_scale - - def loss( - self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True - ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: - # Encode - x = self.proj(action) - z = self.encoder(x, train=train, state_conditioning=obs) - - # Quantize - tokens, z = self.codebook(z) - - # Decode - x = self.decoder(z, train=train, state_conditioning=obs) - mean = self.proj_mean(x) * self.out_scale - - mse = jnp.mean(jnp.square(action - mean)) - mae = jnp.mean(jnp.abs(action - mean)) - - return mse, { - "mse": mse, - "mae": mae, - } - - def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: - """ - Dummy for .init - """ - return self.loss(*args, **kwargs) diff --git a/capvector-pi05/src/openpi/models/vit.py b/capvector-pi05/src/openpi/models/vit.py deleted file mode 100644 index 1408e28d5ae9dd75aaf19aa44dd6c82fd9dda7bb..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models/vit.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py.""" - -from collections.abc import Callable -from typing import Any - -import flax.linen as nn -import jax -import jax.numpy as jnp - -from openpi.models import resnet as models_resnet - -Array = Any -PRNGKey = Any -Shape = tuple[int] -Dtype = Any - - -class IdentityLayer(nn.Module): - """Identity layer, convenient for giving a name to an array.""" - - @nn.compact - def __call__(self, x): - return x - - -class AddPositionEmbs(nn.Module): - """Adds learned positional embeddings to the inputs. - - Attributes: - posemb_init: positional embedding initializer. - """ - - posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] - param_dtype: Dtype = jnp.float32 - - @nn.compact - def __call__(self, inputs): - """Applies the AddPositionEmbs module. - - Args: - inputs: Inputs to the layer. - - Returns: - Output tensor with shape `(bs, timesteps, in_dim)`. - """ - # inputs.shape is (batch_size, seq_len, emb_dim). - assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}" - pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) - pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype) - return inputs + pe - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block.""" - - mlp_dim: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - out_dim: int | None = None - dropout_rate: float = 0.1 - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform() - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6) - - @nn.compact - def __call__(self, inputs, *, deterministic): - """Applies Transformer MlpBlock module.""" - actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim - x = nn.Dense( - features=self.mlp_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=self.kernel_init, - bias_init=self.bias_init, - )( # pytype: disable=wrong-arg-types - inputs - ) - x = nn.gelu(x) - x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) - output = nn.Dense( - features=actual_out_dim, - dtype=self.dtype, - param_dtype=self.param_dtype, - kernel_init=self.kernel_init, - bias_init=self.bias_init, - )( # pytype: disable=wrong-arg-types - x - ) - return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic) - - -class Encoder1DBlock(nn.Module): - """Transformer encoder layer. - - Attributes: - inputs: input data. - mlp_dim: dimension of the mlp on top of attention block. - dtype: the dtype of the computation (default: float32). - dropout_rate: dropout rate. - attention_dropout_rate: dropout for attention heads. - deterministic: bool, deterministic or not (to apply dropout). - num_heads: Number of heads in nn.MultiHeadDotProductAttention - """ - - mlp_dim: int - num_heads: int - dtype: Dtype = jnp.float32 - dropout_rate: float = 0.1 - attention_dropout_rate: float = 0.1 - - @nn.compact - def __call__(self, inputs, deterministic): - """Applies Encoder1DBlock module. 
- - Args: - inputs: Inputs to the layer. - deterministic: Dropout will not be applied when set to true. - - Returns: - output after transformer encoder block. - """ - - # Attention block. - assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}" - x = nn.LayerNorm(dtype=self.dtype)(inputs) - x = nn.MultiHeadDotProductAttention( - dtype=self.dtype, - kernel_init=nn.initializers.xavier_uniform(), - broadcast_dropout=False, - deterministic=deterministic, - dropout_rate=self.attention_dropout_rate, - num_heads=self.num_heads, - # why isn't this true by default??? - force_fp32_for_softmax=True, - )(x, x) - x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) - x = x + inputs - - # MLP block. - y = nn.LayerNorm(dtype=self.dtype)(x) - y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)( - y, deterministic=deterministic - ) - - return x + y, None - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation. - - Attributes: - num_layers: number of layers - mlp_dim: dimension of the mlp on top of attention block - num_heads: Number of heads in nn.MultiHeadDotProductAttention - dropout_rate: dropout rate. - attention_dropout_rate: dropout rate in self attention. - """ - - dtype: jax.typing.DTypeLike - num_layers: int - mlp_dim: int - num_heads: int - dropout_rate: float = 0.1 - attention_dropout_rate: float = 0.1 - add_position_embedding: bool = True - - @nn.compact - def __call__(self, x, *, train): - """Applies Transformer model on the inputs. - - Args: - x: Inputs to the layer. - train: Set to `True` when training. - - Returns: - output of a transformer encoder. - """ - assert x.ndim == 3 # (batch, len, emb) - - if self.add_position_embedding: - x = AddPositionEmbs( - posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. - name="posembed_input", - )(x) - x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) - - x = x.astype(self.dtype) - # Input Encoder - block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,)) - x, _ = nn.scan( - block, - variable_axes={"params": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=nn.broadcast, - length=self.num_layers, - )( - name="encoderblock", - mlp_dim=self.mlp_dim, - dropout_rate=self.dropout_rate, - attention_dropout_rate=self.attention_dropout_rate, - dtype=self.dtype, - num_heads=self.num_heads, - )(x, not train) - return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x) - - -class VisionTransformer(nn.Module): - """VisionTransformer.""" - - dtype: jax.typing.DTypeLike - num_classes: int - patches: Any - transformer: Any - hidden_size: int - resnet: Any | None = None - representation_size: int | None = None - classifier: str = "token" - head_bias_init: float = 0.0 - encoder: type[nn.Module] = Encoder - model_name: str | None = None - - @nn.compact - def __call__(self, inputs, *, train): - x = inputs - # (Possibly partial) ResNet root. - if self.resnet is not None: - width = int(64 * self.resnet.width_factor) - - # Root block. - x = models_resnet.StdConv( - features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root" - )(x) - x = nn.GroupNorm(name="gn_root")(x) - x = nn.relu(x) - x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME") - - # ResNet stages. 
- if self.resnet.num_layers: - x = models_resnet.ResNetStage( - block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1" - )(x) - for i, block_size in enumerate(self.resnet.num_layers[1:], 1): - x = models_resnet.ResNetStage( - block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}" - )(x) - - n, h, w, c = x.shape - - # We can merge s2d+emb into a single conv; it's the same. - x = nn.Conv( - features=self.hidden_size, - kernel_size=self.patches.size, - strides=self.patches.size, - padding="VALID", - name="embedding", - )(x) - - # Here, x is a grid of embeddings. - - # (Possibly partial) Transformer. - if self.transformer is not None: - n, h, w, c = x.shape - x = jnp.reshape(x, [n, h * w, c]) - - # If we want to add a class token, add it here. - if self.classifier in ["token", "token_unpooled"]: - cls = self.param("cls", nn.initializers.zeros, (1, 1, c)) - cls = jnp.tile(cls, [n, 1, 1]) - x = jnp.concatenate([cls, x], axis=1) - - x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train) - - if self.classifier == "token": - x = x[:, 0] - elif self.classifier == "gap": - x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) - elif self.classifier in ["unpooled", "token_unpooled"]: - pass - else: - raise ValueError(f"Invalid classifier={self.classifier}") - - if self.representation_size is not None: - x = nn.Dense(features=self.representation_size, name="pre_logits")(x) - x = nn.tanh(x) - else: - x = IdentityLayer(name="pre_logits")(x) - - if self.num_classes: - x = nn.Dense( - features=self.num_classes, - name="head", - kernel_init=nn.initializers.zeros, - bias_init=nn.initializers.constant(self.head_bias_init), - )(x) - return x diff --git a/capvector-pi05/src/openpi/models_pytorch/gemma_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/gemma_pytorch.py deleted file mode 100644 index 54ee1693970659b0780d0ce8f0413df8854511b2..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/gemma_pytorch.py +++ /dev/null @@ -1,291 +0,0 @@ -from typing import Literal - -import pytest -import torch -from torch import nn -from transformers import GemmaForCausalLM -from transformers import PaliGemmaForConditionalGeneration -from transformers.models.auto import CONFIG_MAPPING -from transformers.models.gemma import modeling_gemma - - -class PaliGemmaWithExpertModel(nn.Module): - def __init__( - self, - vlm_config, - action_expert_config, - use_adarms=None, - precision: Literal["bfloat16", "float32"] = "bfloat16", - ): - if use_adarms is None: - use_adarms = [False, False] - super().__init__() - - vlm_config_hf = CONFIG_MAPPING["paligemma"]() - vlm_config_hf._vocab_size = 257152 # noqa: SLF001 - vlm_config_hf.image_token_index = 257152 - vlm_config_hf.text_config.hidden_size = vlm_config.width - vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim - vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads - vlm_config_hf.text_config.head_dim = vlm_config.head_dim - vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth - vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads - vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" - vlm_config_hf.text_config.vocab_size = 257152 - vlm_config_hf.text_config.use_adarms = use_adarms[0] - vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None - 
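        # use_adarms / adarms_cond_dim configure adaptive RMSNorm conditioning (used, e.g.,
        # by the action expert to inject the timestep embedding); these fields are consumed
        # by the patched Gemma/PaliGemma modules installed from transformers_replace and are
        # ignored by stock transformers.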
vlm_config_hf.vision_config.intermediate_size = 4304 - vlm_config_hf.vision_config.projection_dim = 2048 - vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" - - action_expert_config_hf = CONFIG_MAPPING["gemma"]( - head_dim=action_expert_config.head_dim, - hidden_size=action_expert_config.width, - intermediate_size=action_expert_config.mlp_dim, - num_attention_heads=action_expert_config.num_heads, - num_hidden_layers=action_expert_config.depth, - num_key_value_heads=action_expert_config.num_kv_heads, - vocab_size=257152, - hidden_activation="gelu_pytorch_tanh", - torch_dtype="float32", - use_adarms=use_adarms[1], - adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, - ) - - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) - self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) - self.gemma_expert.model.embed_tokens = None - - self.to_bfloat16_for_selected_params(precision) - - def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): - if precision == "bfloat16": - self.to(dtype=torch.bfloat16) - elif precision == "float32": - self.to(dtype=torch.float32) - return - else: - raise ValueError(f"Invalid precision: {precision}") - - params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", - "input_layernorm", - "post_attention_layernorm", - "model.norm", - ] - - for name, param in self.named_parameters(): - if any(selector in name for selector in params_to_keep_float32): - param.data = param.data.to(dtype=torch.float32) - - def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) - - def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) - - def forward( - self, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None, - inputs_embeds: list[torch.FloatTensor] | None = None, - use_cache: bool | None = None, - adarms_cond: list[torch.Tensor] | None = None, - output_hidden_states: bool | None = None, - ): - if adarms_cond is None: - adarms_cond = [None, None] - if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( - inputs_embeds=inputs_embeds[0], - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - adarms_cond=adarms_cond[0] if adarms_cond is not None else None, - ) - prefix_past_key_values = prefix_output.past_key_values - prefix_output = prefix_output.last_hidden_state - suffix_output = None - elif inputs_embeds[0] is None: - suffix_output = self.gemma_expert.model.forward( - inputs_embeds=inputs_embeds[1], - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - adarms_cond=adarms_cond[1] if adarms_cond is not None else None, - ) - suffix_output = suffix_output.last_hidden_state - prefix_output = None - prefix_past_key_values = None - else: - models = [self.paligemma.language_model, self.gemma_expert.model] - num_layers = self.paligemma.config.text_config.num_hidden_layers - - # Check if gradient checkpointing is enabled for any of the models - use_gradient_checkpointing = ( - 
hasattr(self.gemma_expert.model, "gradient_checkpointing") - and self.gemma_expert.model.gradient_checkpointing - and self.training - ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) - - # Force enable gradient checkpointing if we're in training mode and the model supports it - if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"): - if not self.gemma_expert.model.gradient_checkpointing: - print("Forcing gradient checkpointing to be enabled for Gemma expert model") - self.gemma_expert.model.gradient_checkpointing = True - use_gradient_checkpointing = True - - # Debug gradient checkpointing status - if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed: - print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") - print(f"Model training mode: {self.training}") - print( - f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}" - ) - if hasattr(self.gemma_expert.model, "gradient_checkpointing"): - print( - f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}" - ) - self._debug_gc_printed = True - - # Define the complete layer computation function for gradient checkpointing - def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): - models = [self.paligemma.language_model, self.gemma_expert.model] - - query_states = [] - key_states = [] - value_states = [] - gates = [] - for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 - gates.append(gate) - - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - query_states.append(query_state) - key_states.append(key_state) - value_states.append(value_state) - - # Concatenate and process attention - query_states = torch.cat(query_states, dim=2) - key_states = torch.cat(key_states, dim=2) - value_states = torch.cat(value_states, dim=2) - - dummy_tensor = torch.zeros( - query_states.shape[0], - query_states.shape[2], - query_states.shape[-1], - device=query_states.device, - dtype=query_states.dtype, - ) - cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) - query_states, key_states = modeling_gemma.apply_rotary_pos_emb( - query_states, key_states, cos, sin, unsqueeze_dim=1 - ) - - batch_size = query_states.shape[0] - scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling - - # Attention computation - att_output, _ = modeling_gemma.eager_attention_forward( - self.paligemma.language_model.layers[layer_idx].self_attn, - query_states, - key_states, - value_states, - attention_mask, - scaling, - ) - # Get head_dim from the current layer, not from the model - head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim - att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) - - # Process layer outputs - outputs_embeds = [] - start_pos = 0 - for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] - end_pos = start_pos + hidden_states.shape[1] - - if 
att_output.dtype != layer.self_attn.o_proj.weight.dtype: - att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) - out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) - - # first residual - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 - after_first_residual = out_emb.clone() - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) - # Convert to bfloat16 if the next layer (mlp) uses bfloat16 - if layer.mlp.up_proj.weight.dtype == torch.bfloat16: - out_emb = out_emb.to(dtype=torch.bfloat16) - - out_emb = layer.mlp(out_emb) - # second residual - out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 - outputs_embeds.append(out_emb) - start_pos = end_pos - - return outputs_embeds - - # Process all layers with gradient checkpointing if enabled - all_hidden_states = () if output_hidden_states else None - for layer_idx in range(num_layers): - if output_hidden_states: - all_hidden_states += (inputs_embeds,) - if use_gradient_checkpointing: - inputs_embeds = torch.utils.checkpoint.checkpoint( - compute_layer_complete, - layer_idx, - inputs_embeds, - attention_mask, - position_ids, - adarms_cond, - use_reentrant=False, - preserve_rng_state=False, - ) - else: - inputs_embeds = compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond - ) - - # Old code removed - now using compute_layer_complete function above - - # final norm - # Define final norm computation function for gradient checkpointing - def compute_final_norms(inputs_embeds, adarms_cond): - outputs_embeds = [] - for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) - outputs_embeds.append(out_emb) - return outputs_embeds - - # Apply gradient checkpointing to final norm if enabled - if use_gradient_checkpointing: - outputs_embeds = torch.utils.checkpoint.checkpoint( - compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False - ) - else: - outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) - - if output_hidden_states: - all_hidden_states += (outputs_embeds,) - - prefix_output = outputs_embeds[0] - suffix_output = outputs_embeds[1] - prefix_past_key_values = None - - if output_hidden_states: - return [prefix_output, suffix_output], prefix_past_key_values, all_hidden_states - - return [prefix_output, suffix_output], prefix_past_key_values diff --git a/capvector-pi05/src/openpi/models_pytorch/pi0_align_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/pi0_align_pytorch.py deleted file mode 100644 index abd9a73f968808fe49d7ca5d168c095ab699e466..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/pi0_align_pytorch.py +++ /dev/null @@ -1,528 +0,0 @@ -import logging -import math - -import torch -from torch import Tensor -from torch import nn -import torch.nn.functional as F # noqa: N812 - -import openpi.models.gemma as _gemma -from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel -import openpi.models_pytorch.preprocessing_pytorch as _preprocessing - -from vggt.utils.load_fn import preprocess_images_from_openpi -from vggt.heads.utils import custom_pooling - - -def get_safe_dtype(target_dtype, device_type): - """Get a safe dtype for the given device type.""" - if device_type == "cpu": - # CPU doesn't support bfloat16, use float32 instead - if target_dtype == torch.bfloat16: - return torch.float32 - if target_dtype == 
torch.float64: - return torch.float64 - return target_dtype - - -def create_sinusoidal_pos_embedding( - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" -) -> Tensor: - """Computes sine-cosine positional embedding vectors for scalar positions.""" - if dimension % 2 != 0: - raise ValueError(f"dimension ({dimension}) must be divisible by 2") - - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") - - dtype = get_safe_dtype(torch.float64, device.type) - fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) - period = min_period * (max_period / min_period) ** fraction - - # Compute the outer product - scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - - -def sample_beta(alpha, beta, bsize, device): - alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) - beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) - dist = torch.distributions.Beta(alpha_t, beta_t) - return dist.sample((bsize,)) - - -def make_att_2d_masks(pad_masks, att_masks): - """Copied from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on - it and 0 where it shares the same attention mask as the previous token. 
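    Example (illustrative): with att_masks = [[0, 0, 0, 1, 1, 1]] and all-True pad_masks,
    the cumulative sum is [0, 0, 0, 1, 2, 3]; query i may attend to key j iff
    cumsum[j] <= cumsum[i], so the three prefix tokens attend to each other
    bidirectionally while tokens 3-5 attend causally to all earlier tokens.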
- """ - if att_masks.ndim != 2: - raise ValueError(att_masks.ndim) - if pad_masks.ndim != 2: - raise ValueError(pad_masks.ndim) - - cumsum = torch.cumsum(att_masks, dim=1) - att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] - pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - return att_2d_masks & pad_2d_masks - - -class PI0Pytorch(nn.Module): - def __init__(self, config, extra_config): - super().__init__() - self.config = config - self.pi05 = config.pi05 - - paligemma_config = _gemma.get_config(config.paligemma_variant) - action_expert_config = _gemma.get_config(config.action_expert_variant) - - self.LLM_width = paligemma_config.width - - self.paligemma_with_expert = PaliGemmaWithExpertModel( - paligemma_config, - action_expert_config, - use_adarms=[False, True] if self.pi05 else [False, False], - precision=config.dtype, - ) - - self.action_in_proj = nn.Linear(32, action_expert_config.width) - self.action_out_proj = nn.Linear(action_expert_config.width, 32) - - if self.pi05: - self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) - self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) - else: - self.state_proj = nn.Linear(32, action_expert_config.width) - self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) - self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) - - torch.set_float32_matmul_precision("high") - self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") - - # Initialize gradient checkpointing flag - self.gradient_checkpointing_enabled = False - - # Specific config for SpatialForcing alignment - self.vla_layers_align = extra_config.vla_layers_align - self.vggt_layers_align = extra_config.vggt_layers_align - self.pooling_func = extra_config.pooling_func - self.use_vggt_pe = extra_config.use_vggt_pe - - msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`." 
- try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - - def gradient_checkpointing_enable(self): - """Enable gradient checkpointing for memory optimization.""" - self.gradient_checkpointing_enabled = True - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True - self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True - - logging.info("Enabled gradient checkpointing for PI0Pytorch model") - - def gradient_checkpointing_disable(self): - """Disable gradient checkpointing.""" - self.gradient_checkpointing_enabled = False - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False - self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False - - logging.info("Disabled gradient checkpointing for PI0Pytorch model") - - def is_gradient_checkpointing_enabled(self): - """Check if gradient checkpointing is enabled.""" - return self.gradient_checkpointing_enabled - - def _apply_checkpoint(self, func, *args, **kwargs): - """Helper method to apply gradient checkpointing if enabled.""" - if self.gradient_checkpointing_enabled and self.training: - return torch.utils.checkpoint.checkpoint( - func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs - ) - return func(*args, **kwargs) - - def _prepare_attention_masks_4d(self, att_2d_masks): - """Helper method to prepare 4D attention masks for transformer.""" - att_2d_masks_4d = att_2d_masks[:, None, :, :] - return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) - - def _preprocess_observation(self, observation, *, train=True, get_wo_aug=False): - """Helper method to preprocess observation.""" - observation = _preprocessing.preprocess_observation_pytorch(observation, train=train, get_wo_aug=get_wo_aug) - return ( - list(observation.images.values()), - list(observation.img_wo_aug.values()) if get_wo_aug else None, - list(observation.image_padding_mask.values()), - list(observation.image_masks.values()), - observation.tokenized_prompt, - observation.tokenized_prompt_mask, - observation.state, - ) - - def sample_noise(self, shape, device): - return torch.normal( - mean=0.0, - std=1.0, - size=shape, - dtype=torch.float32, - device=device, - ) - - def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) - time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) - - def embed_prefix( - self, images, img_masks, lang_tokens, lang_masks - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed images with SigLIP and language tokens with embedding layer to prepare - for PaliGemma transformer processing. 
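        Returns the concatenated prefix embeddings, their padding mask, the per-token
        attention-mask bits (all zeros here, i.e. the prefix is fully bidirectional),
        and the number of image tokens in the prefix.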
- """ - embs = [] - pad_masks = [] - att_masks = [] - - # Process images - for img, img_mask in zip(images, img_masks, strict=True): - - def image_embed_func(img): - return self.paligemma_with_expert.embed_image(img) - - img_emb = self._apply_checkpoint(image_embed_func, img) - - bsize, num_img_embs = img_emb.shape[:2] - - embs.append(img_emb) - pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) - - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs - - img_len = len(att_masks) - - # Process language tokens - def lang_embed_func(lang_tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) - - lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) - - embs.append(lang_emb) - pad_masks.append(lang_masks) - - # full attention between image and language inputs - num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) - - # Get batch size from the first dimension of the concatenated tensors - bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - - return embs, pad_masks, att_masks, img_len - - def embed_suffix(self, state, noisy_actions, timestep): - """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" - embs = [] - pad_masks = [] - att_masks = [] - - if not self.pi05: - if self.state_proj.weight.dtype == torch.float32: - state = state.to(torch.float32) - - # Embed state - def state_proj_func(state): - return self.state_proj(state) - - state_emb = self._apply_checkpoint(state_proj_func, state) - - embs.append(state_emb[:, None, :]) - bsize = state_emb.shape[0] - device = state_emb.device - - state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) - pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions - att_masks += [1] - - # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] - time_emb = create_sinusoidal_pos_embedding( - timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device - ) - time_emb = time_emb.type(dtype=timestep.dtype) - - # Fuse timestep + action information using an MLP - def action_proj_func(noisy_actions): - return self.action_in_proj(noisy_actions) - - action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) - - if not self.pi05: - time_emb = time_emb[:, None, :].expand_as(action_emb) - action_time_emb = torch.cat([action_emb, time_emb], dim=2) - - # Apply MLP layers - def mlp_func(action_time_emb): - x = self.action_time_mlp_in(action_time_emb) - x = F.silu(x) # swish == silu - return self.action_time_mlp_out(x) - - action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) - adarms_cond = None - else: - # time MLP (for adaRMS) - def time_mlp_func(time_emb): - x = self.time_mlp_in(time_emb) - x = F.silu(x) # swish == silu - x = self.time_mlp_out(x) - return F.silu(x) - - time_emb = self._apply_checkpoint(time_mlp_func, time_emb) - action_time_emb = action_emb - adarms_cond = time_emb - - # Add to input tokens - embs.append(action_time_emb) - - bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, 
device=timestep.device) - pad_masks.append(action_time_mask) - - # Set attention masks so that image, language and state inputs do not attend to action tokens - att_masks += [1] + ([0] * (self.config.action_horizon - 1)) - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - - return embs, pad_masks, att_masks, adarms_cond - - def forward(self, observation, actions, vggt, align_proj, noise=None, time=None) -> Tensor: - """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - images, img_wo_aug, img_padding_mask, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation( - observation, train=True, get_wo_aug=True - ) - img_resize_wo_aug = preprocess_images_from_openpi(img_wo_aug) # specific for VGGT with 518px input - - # =================================== VLA action loss =================================== - - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - - time_expanded = time[:, None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - - prefix_embs, prefix_pad_masks, prefix_att_masks, img_len = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks - ) - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) - if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype - == torch.bfloat16 - ): - suffix_embs = suffix_embs.to(dtype=torch.bfloat16) - prefix_embs = prefix_embs.to(dtype=torch.bfloat16) - - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - - # Prepare attention masks - att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) - - # Apply gradient checkpointing if enabled - def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): - (_, suffix_out), _, all_hidden_states = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond], - output_hidden_states=True, - ) - return suffix_out, all_hidden_states - - suffix_out, all_hidden_states = self._apply_checkpoint( - forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond - ) - - suffix_out = suffix_out[:, -self.config.action_horizon :] - suffix_out = suffix_out.to(dtype=torch.float32) - - # Apply gradient checkpointing to final action projection if enabled - def action_out_proj_func(suffix_out): - return self.action_out_proj(suffix_out) - - v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) - - action_loss = F.mse_loss(u_t, v_t) - - # =================================== Alignment loss =================================== - - # VLA hidden states - (prefix_hidden, _) = all_hidden_states[self.vla_layers_align] # 18 total layers of paligemma - vision_hidden = prefix_hidden[:, :img_len, :] - - # VGGT hidden states - with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad(): - vggt_output = vggt(img_resize_wo_aug) - agg_vggt_hidden = 
vggt_output["features"][self.vggt_layers_align] # 24 for total layers of VGGT - patch_start_idx = vggt_output["patch_start_idx"] - original_img = vggt_output["images"] - vggt_hidden = agg_vggt_hidden[:, :, patch_start_idx:, :] - - # Resample VGGT hidden states to match the resolution of VLA hidden states - H, W = original_img.shape[-2:] - patch_h, patch_w = H // vggt.patch_size, W // vggt.patch_size - vggt_hidden = custom_pooling( - vggt_hidden, (patch_h, patch_w), (H, W), vision_hidden, self.pooling_func, self.use_vggt_pe - ) - - # empty image feature masks for alignment loss - tokens_per_img = img_len // len(images) - img_masks_stack = torch.stack(img_masks, dim=1) - align_mask = torch.repeat_interleave(img_masks_stack, repeats=tokens_per_img, dim=1) - - # useless non-rectangular image padding feature masks for alignment loss - img_padding_mask = torch.stack(img_padding_mask, dim=1) - target_size = img_padding_mask.shape[-1] // 14 # 224/14, where 14 is the patch size of Gemma encoder - mask_downsampled = F.interpolate( - img_padding_mask.float(), - size=(target_size, target_size), - mode='nearest' - ).bool().flatten(start_dim=1) - assert align_mask.shape == mask_downsampled.shape, \ - "align_mask shape don't match img_padding_mask shape, please manually modify the patch size of Gemma encoder (now is 14)" - align_mask = mask_downsampled & align_mask - - # calculate align loss - with torch.autocast("cuda", dtype=torch.bfloat16): - align_loss = align_proj(vision_hidden, vggt_hidden, align_mask) - - return action_loss, align_loss - - @torch.no_grad() - def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" - bsize = observation.state.shape[0] - if noise is None: - actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) - noise = self.sample_noise(actions_shape, device) - - images, _, _, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False) - - prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - - # Compute image and language key value cache - prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) - self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 - - _, past_key_values = self.paligemma_with_expert.forward( - attention_mask=prefix_att_2d_masks_4d, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, None], - use_cache=True, - ) - - dt = -1.0 / num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) - - x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - v_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - - # Euler step - use new tensor assignment instead of in-place operation - x_t = x_t + dt * v_t - time += dt - return x_t - - def denoise_step( - self, - state, - prefix_pad_masks, - past_key_values, - x_t, - timestep, - ): - """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) - - suffix_len = 
suffix_pad_masks.shape[1] - batch_size = prefix_pad_masks.shape[0] - prefix_len = prefix_pad_masks.shape[1] - - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - - suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - - full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) - - prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - - # Prepare attention masks - full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) - self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 - - outputs_embeds, _ = self.paligemma_with_expert.forward( - attention_mask=full_att_2d_masks_4d, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=[None, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond], - ) - - suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.action_horizon :] - suffix_out = suffix_out.to(dtype=torch.float32) - return self.action_out_proj(suffix_out) diff --git a/capvector-pi05/src/openpi/models_pytorch/pi0_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/pi0_pytorch.py deleted file mode 100644 index c32308574dd3d76149a1b1236e5d956d6de100cf..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/pi0_pytorch.py +++ /dev/null @@ -1,461 +0,0 @@ -import logging -import math - -import torch -from torch import Tensor -from torch import nn -import torch.nn.functional as F # noqa: N812 - -import openpi.models.gemma as _gemma -from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel -import openpi.models_pytorch.preprocessing_pytorch as _preprocessing - - -def get_safe_dtype(target_dtype, device_type): - """Get a safe dtype for the given device type.""" - if device_type == "cpu": - # CPU doesn't support bfloat16, use float32 instead - if target_dtype == torch.bfloat16: - return torch.float32 - if target_dtype == torch.float64: - return torch.float64 - return target_dtype - - -def create_sinusoidal_pos_embedding( - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" -) -> Tensor: - """Computes sine-cosine positional embedding vectors for scalar positions.""" - if dimension % 2 != 0: - raise ValueError(f"dimension ({dimension}) must be divisible by 2") - - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") - - dtype = get_safe_dtype(torch.float64, device.type) - fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) - period = min_period * (max_period / min_period) ** fraction - - # Compute the outer product - scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - - -def sample_beta(alpha, beta, bsize, device): - alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) - beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) - dist = torch.distributions.Beta(alpha_t, beta_t) - return dist.sample((bsize,)) - - -def make_att_2d_masks(pad_masks, att_masks): - """Copied from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. 
This way `mask_ar` int[B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on - it and 0 where it shares the same attention mask as the previous token. - """ - if att_masks.ndim != 2: - raise ValueError(att_masks.ndim) - if pad_masks.ndim != 2: - raise ValueError(pad_masks.ndim) - - cumsum = torch.cumsum(att_masks, dim=1) - att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] - pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - return att_2d_masks & pad_2d_masks - - -class PI0Pytorch(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.pi05 = config.pi05 - - paligemma_config = _gemma.get_config(config.paligemma_variant) - action_expert_config = _gemma.get_config(config.action_expert_variant) - - self.paligemma_with_expert = PaliGemmaWithExpertModel( - paligemma_config, - action_expert_config, - use_adarms=[False, True] if self.pi05 else [False, False], - precision=config.dtype, - ) - - self.action_in_proj = nn.Linear(32, action_expert_config.width) - self.action_out_proj = nn.Linear(action_expert_config.width, 32) - - if self.pi05: - self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) - self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) - else: - self.state_proj = nn.Linear(32, action_expert_config.width) - self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) - self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) - - torch.set_float32_matmul_precision("high") - self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") - - # Initialize gradient checkpointing flag - self.gradient_checkpointing_enabled = False - - msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`." 
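The attention-pattern examples in the `make_att_2d_masks` docstring above are easiest to verify with a toy input. The snippet below is a minimal, self-contained sketch that reproduces the same cumulative-sum logic; the `pad`/`att` values are made-up toy inputs, not taken from the model:

```python
import torch

def make_att_2d_masks(pad_masks, att_masks):
    # Tokens may attend to tokens whose cumulative att_mask value is <= their own.
    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    return att_2d_masks & pad_2d_masks

pad = torch.ones(1, 6, dtype=torch.bool)   # no padding in this toy example
att = torch.tensor([[0, 0, 0, 1, 1, 1]])   # 3 prefix tokens followed by 3 causal tokens
print(make_att_2d_masks(pad, att).int()[0])
# -> the first 3 tokens attend only to each other (prefix-LM block);
#    tokens 4-6 attend to the prefix and causally to themselves.
```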
- try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - - def gradient_checkpointing_enable(self): - """Enable gradient checkpointing for memory optimization.""" - self.gradient_checkpointing_enabled = True - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True - self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True - - logging.info("Enabled gradient checkpointing for PI0Pytorch model") - - def gradient_checkpointing_disable(self): - """Disable gradient checkpointing.""" - self.gradient_checkpointing_enabled = False - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False - self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False - - logging.info("Disabled gradient checkpointing for PI0Pytorch model") - - def is_gradient_checkpointing_enabled(self): - """Check if gradient checkpointing is enabled.""" - return self.gradient_checkpointing_enabled - - def _apply_checkpoint(self, func, *args, **kwargs): - """Helper method to apply gradient checkpointing if enabled.""" - if self.gradient_checkpointing_enabled and self.training: - return torch.utils.checkpoint.checkpoint( - func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs - ) - return func(*args, **kwargs) - - def _prepare_attention_masks_4d(self, att_2d_masks): - """Helper method to prepare 4D attention masks for transformer.""" - att_2d_masks_4d = att_2d_masks[:, None, :, :] - return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) - - def _preprocess_observation(self, observation, *, train=True): - """Helper method to preprocess observation.""" - observation = _preprocessing.preprocess_observation_pytorch(observation, train=train) - return ( - list(observation.images.values()), - list(observation.image_masks.values()), - observation.tokenized_prompt, - observation.tokenized_prompt_mask, - observation.state, - ) - - def sample_noise(self, shape, device): - return torch.normal( - mean=0.0, - std=1.0, - size=shape, - dtype=torch.float32, - device=device, - ) - - def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) - time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) - - def embed_prefix( - self, images, img_masks, lang_tokens, lang_masks - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed images with SigLIP and language tokens with embedding layer to prepare - for PaliGemma transformer processing. 
- """ - embs = [] - pad_masks = [] - att_masks = [] - - # Process images - for img, img_mask in zip(images, img_masks, strict=True): - - def image_embed_func(img): - return self.paligemma_with_expert.embed_image(img) - - img_emb = self._apply_checkpoint(image_embed_func, img) - - bsize, num_img_embs = img_emb.shape[:2] - - embs.append(img_emb) - pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) - - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs - - # Process language tokens - def lang_embed_func(lang_tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) - - lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) - - embs.append(lang_emb) - pad_masks.append(lang_masks) - - # full attention between image and language inputs - num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) - - # Get batch size from the first dimension of the concatenated tensors - bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - - return embs, pad_masks, att_masks - - def embed_suffix(self, state, noisy_actions, timestep): - """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" - embs = [] - pad_masks = [] - att_masks = [] - - if not self.pi05: - if self.state_proj.weight.dtype == torch.float32: - state = state.to(torch.float32) - - # Embed state - def state_proj_func(state): - return self.state_proj(state) - - state_emb = self._apply_checkpoint(state_proj_func, state) - - embs.append(state_emb[:, None, :]) - bsize = state_emb.shape[0] - device = state_emb.device - - state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) - pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions - att_masks += [1] - - # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] - time_emb = create_sinusoidal_pos_embedding( - timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device - ) - time_emb = time_emb.type(dtype=timestep.dtype) - - # Fuse timestep + action information using an MLP - def action_proj_func(noisy_actions): - return self.action_in_proj(noisy_actions) - - action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) - - if not self.pi05: - time_emb = time_emb[:, None, :].expand_as(action_emb) - action_time_emb = torch.cat([action_emb, time_emb], dim=2) - - # Apply MLP layers - def mlp_func(action_time_emb): - x = self.action_time_mlp_in(action_time_emb) - x = F.silu(x) # swish == silu - return self.action_time_mlp_out(x) - - action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) - adarms_cond = None - else: - # time MLP (for adaRMS) - def time_mlp_func(time_emb): - x = self.time_mlp_in(time_emb) - x = F.silu(x) # swish == silu - x = self.time_mlp_out(x) - return F.silu(x) - - time_emb = self._apply_checkpoint(time_mlp_func, time_emb) - action_time_emb = action_emb - adarms_cond = time_emb - - # Add to input tokens - embs.append(action_time_emb) - - bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) - 
pad_masks.append(action_time_mask) - - # Set attention masks so that image, language and state inputs do not attend to action tokens - att_masks += [1] + ([0] * (self.config.action_horizon - 1)) - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - - return embs, pad_masks, att_masks, adarms_cond - - def forward(self, observation, actions, noise=None, time=None) -> Tensor: - """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True) - - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - - time_expanded = time[:, None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) - if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype - == torch.bfloat16 - ): - suffix_embs = suffix_embs.to(dtype=torch.bfloat16) - prefix_embs = prefix_embs.to(dtype=torch.bfloat16) - - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - - # Prepare attention masks - att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) - - # Apply gradient checkpointing if enabled - def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks_4d, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond], - ) - return suffix_out - - suffix_out = self._apply_checkpoint( - forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond - ) - - suffix_out = suffix_out[:, -self.config.action_horizon :] - suffix_out = suffix_out.to(dtype=torch.float32) - - # Apply gradient checkpointing to final action projection if enabled - def action_out_proj_func(suffix_out): - return self.action_out_proj(suffix_out) - - v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) - - return F.mse_loss(u_t, v_t, reduction="none") - - @torch.no_grad() - def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" - bsize = observation.state.shape[0] - if noise is None: - actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) - noise = self.sample_noise(actions_shape, device) - - images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False) - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - - # Compute image and language key value 
cache - prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) - self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 - - _, past_key_values = self.paligemma_with_expert.forward( - attention_mask=prefix_att_2d_masks_4d, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, None], - use_cache=True, - ) - - dt = -1.0 / num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) - - x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - v_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - - # Euler step - use new tensor assignment instead of in-place operation - x_t = x_t + dt * v_t - time += dt - return x_t - - def denoise_step( - self, - state, - prefix_pad_masks, - past_key_values, - x_t, - timestep, - ): - """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) - - suffix_len = suffix_pad_masks.shape[1] - batch_size = prefix_pad_masks.shape[0] - prefix_len = prefix_pad_masks.shape[1] - - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - - suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - - full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) - - prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - - # Prepare attention masks - full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) - self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 - - outputs_embeds, _ = self.paligemma_with_expert.forward( - attention_mask=full_att_2d_masks_4d, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=[None, suffix_embs], - use_cache=False, - adarms_cond=[None, adarms_cond], - ) - - suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.action_horizon :] - suffix_out = suffix_out.to(dtype=torch.float32) - return self.action_out_proj(suffix_out) diff --git a/capvector-pi05/src/openpi/models_pytorch/preprocessing_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/preprocessing_pytorch.py deleted file mode 100644 index cb90f3672a761f4dcd6b2a51ba51e12c2d2ecce1..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/preprocessing_pytorch.py +++ /dev/null @@ -1,190 +0,0 @@ -from collections.abc import Sequence -import logging - -import torch - -from openpi.shared import image_tools - -logger = logging.getLogger("openpi") - -# Constants moved from model.py -IMAGE_KEYS = ( - "base_0_rgb", - "left_wrist_0_rgb", - "right_wrist_0_rgb", -) - -IMAGE_RESOLUTION = (224, 224) - - -def preprocess_observation_pytorch( - observation, - *, - train: bool = False, - get_wo_aug: bool = False, - image_keys: Sequence[str] = IMAGE_KEYS, - image_resolution: tuple[int, int] = IMAGE_RESOLUTION, -): - """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. - - This function avoids complex type annotations that can cause torch.compile issues. 
- """ - if not set(image_keys).issubset(observation.images): - raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") - - batch_shape = observation.state.shape[:-1] - - out_images = {} - out_images_wo_aug = {} - for key in image_keys: - image = observation.images[key] - - # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats - # Handle both [B, C, H, W] and [B, H, W, C] formats - is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 - - if is_channels_first: - # Convert [B, C, H, W] to [B, H, W, C] for processing - image = image.permute(0, 2, 3, 1) - - if image.shape[1:3] != image_resolution: - logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") - image = image_tools.resize_with_pad_torch(image, *image_resolution) - - if train: - # Convert from [-1, 1] to [0, 1] for PyTorch augmentations - image = image / 2.0 + 0.5 - - # Apply PyTorch-based augmentations - if "wrist" not in key and not get_wo_aug: - # Geometric augmentations for non-wrist cameras - height, width = image.shape[1:3] - - # Random crop and resize - crop_height = int(height * 0.95) - crop_width = int(width * 0.95) - - # Random crop - max_h = height - crop_height - max_w = width - crop_width - if max_h > 0 and max_w > 0: - # Use tensor operations instead of .item() for torch.compile compatibility - start_h = torch.randint(0, max_h + 1, (1,), device=image.device) - start_w = torch.randint(0, max_w + 1, (1,), device=image.device) - image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] - - # Resize back to original size - image = torch.nn.functional.interpolate( - image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - size=(height, width), - mode="bilinear", - align_corners=False, - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Random rotation (small angles) - # Use tensor operations instead of .item() for torch.compile compatibility - angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees - if torch.abs(angle) > 0.1: # Only rotate if angle is significant - # Convert to radians - angle_rad = angle * torch.pi / 180.0 - - # Create rotation matrix - cos_a = torch.cos(angle_rad) - sin_a = torch.sin(angle_rad) - - # Apply rotation using grid_sample - grid_x = torch.linspace(-1, 1, width, device=image.device) - grid_y = torch.linspace(-1, 1, height, device=image.device) - - # Create meshgrid - grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") - - # Expand to batch dimension - grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) - grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) - - # Apply rotation transformation - grid_x_rot = grid_x * cos_a - grid_y * sin_a - grid_y_rot = grid_x * sin_a + grid_y * cos_a - - # Stack and reshape for grid_sample - grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) - - image = torch.nn.functional.grid_sample( - image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - grid, - mode="bilinear", - padding_mode="zeros", - align_corners=False, - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Save original images (with color_aug, but without rotation) for VGGT input - img_inv_padding = image.clone() if not (image == 0).all() else torch.ones_like(image) - img_inv_padding[~observation.image_padding_mask[key]] = 1.0 # Set padding areas to white - img_inv_padding = img_inv_padding.permute(0, 3, 1, 2) if is_channels_first else img_inv_padding - 
out_images_wo_aug[key] = img_inv_padding.contiguous() - - # Color augmentations for all cameras - # Random brightness - # Use tensor operations instead of .item() for torch.compile compatibility - brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 - image = image * brightness_factor - - # Random contrast - # Use tensor operations instead of .item() for torch.compile compatibility - contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 - mean = image.mean(dim=[1, 2, 3], keepdim=True) - image = (image - mean) * contrast_factor + mean - - # Random saturation (convert to HSV, modify S, convert back) - # For simplicity, we'll just apply a random scaling to the color channels - # Use tensor operations instead of .item() for torch.compile compatibility - saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5 - gray = image.mean(dim=-1, keepdim=True) - image = gray + (image - gray) * saturation_factor - - # Clamp values to [0, 1] - image = torch.clamp(image, 0, 1) - - # Back to [-1, 1] - image = image * 2.0 - 1.0 - - # Convert back to [B, C, H, W] format if it was originally channels-first - if is_channels_first: - image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] - - out_images[key] = image - - # obtain mask - out_masks = {} - for key in out_images: - if key not in observation.image_masks: - # do not mask by default - out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) - else: - out_masks[key] = observation.image_masks[key] - - # obtain image padding mask for non-rectangular images - img_padding_mask = {key: observation.image_padding_mask[key] for key in out_images} - - # Create a simple object with the required attributes instead of using the complex Observation class - class SimpleProcessedObservation: - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - result_kwargs = { - "images": out_images, - "image_padding_mask": img_padding_mask, - "image_masks": out_masks, - "state": observation.state, - "tokenized_prompt": observation.tokenized_prompt, - "tokenized_prompt_mask": observation.tokenized_prompt_mask, - "token_ar_mask": observation.token_ar_mask, - "token_loss_mask": observation.token_loss_mask, - } - - if get_wo_aug: - result_kwargs["img_wo_aug"] = out_images_wo_aug - - return SimpleProcessedObservation(**result_kwargs) \ No newline at end of file diff --git a/capvector-pi05/src/openpi/models_pytorch/projectors.py b/capvector-pi05/src/openpi/models_pytorch/projectors.py deleted file mode 100644 index d9ff14bd50cb0d343f5450dae307aa2f3db871ac..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/projectors.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Implementation of additional projectors for additional inputs to the VLA models.""" -import torch -import torch.nn as nn -import openpi.models.gemma as _gemma - -class AlignProjector(nn.Module): - """ - calculate the alignment between LLM and VGGT embeddings. 
- """ - def __init__( - self, - llm_dim: int, - vggt_dim: int, - use_vlm_norm: bool = False, - ) -> None: - super().__init__() - - self.llm_dim = llm_dim - self.vggt_dim = vggt_dim - - self.fc1 = nn.Linear(self.llm_dim, 2 * self.vggt_dim, bias=True) - self.fc2 = nn.Linear(2 * self.vggt_dim, 2 * self.vggt_dim, bias=True) - self.act_fn1 = nn.GELU() - - self.vlm_norm = nn.LayerNorm(llm_dim) if use_vlm_norm else None - - self.initialize_weights() - - def initialize_weights(self): - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - def align_dimension(self, LLM_embedding: torch.Tensor = None) -> torch.Tensor: - if self.vlm_norm is not None: - LLM_embedding = self.vlm_norm(LLM_embedding) - projected_features = self.fc1(LLM_embedding) - projected_features = self.act_fn1(projected_features) - projected_features = self.fc2(projected_features) - return projected_features - - def compute_align_loss_cosine(self, vision_hidden, vggt_hidden, align_mask): - # vision_hidden has a shape of (bs, N, D) - def mean_flat(x): - return torch.mean(x, dim=list(range(1, len(x.size())))) - align_loss = 0 - bsz = vision_hidden.shape[0] - for _vision, _vggt, _mask in zip(vision_hidden, vggt_hidden, align_mask): - _vision = torch.nn.functional.normalize(_vision, dim=-1) - _vggt = torch.nn.functional.normalize(_vggt, dim=-1) - # align_loss += 1 - torch.mean(vision_hidden * vggt_hidden).sum(dim=-1).mean() - align_loss += 1 - mean_flat((_vision * _vggt)[_mask].sum(dim=-1)) # Cosine similarity loss - align_loss /= bsz # Average over batch size - return align_loss - - def forward(self, LLM_emb, target_emb, align_mask): - # project vla dimension and calculate align loss - LLM_emb = self.align_dimension(LLM_emb) - align_loss = self.compute_align_loss_cosine(LLM_emb, target_emb, align_mask).mean() # mean for sequence length - return align_loss diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py deleted file mode 100644 index 472dd16f9377a6868e5f7659a76519b37483d31a..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py +++ /dev/null @@ -1,173 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional -from ...configuration_utils import PretrainedConfig - - -class GemmaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma-7B. - e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 256000): - Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`GemmaModel`] - hidden_size (`int`, *optional*, defaults to 3072): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 24576): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 28): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 16): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The legacy activation function. It is overwritten by the `hidden_activation`. - hidden_activation (`str` or `function`, *optional*): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. - max_position_embeddings (`int`, *optional*, defaults to 8192): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. 
- tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_adarms (`bool`, *optional*, defaults to `False`): - Whether to use ADARMS. - adarms_cond_dim (`int`, *optional*, defaults to `None`): - The dimension of the ADARMS condition. - ```python - >>> from transformers import GemmaModel, GemmaConfig - >>> # Initializing a Gemma gemma-7b style configuration - >>> configuration = GemmaConfig() - >>> # Initializing a model from the gemma-7b style configuration - >>> model = GemmaModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "gemma" - keys_to_ignore_at_inference = ["past_key_values"] - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=256000, - hidden_size=3072, - intermediate_size=24576, - num_hidden_layers=28, - num_attention_heads=16, - num_key_value_heads=16, - head_dim=256, - hidden_act="gelu_pytorch_tanh", - hidden_activation=None, - max_position_embeddings=8192, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - eos_token_id=1, - bos_token_id=2, - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - use_adarms: bool = False, - adarms_cond_dim: Optional[int] = None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.hidden_activation = hidden_activation - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.use_adarms = use_adarms - self.adarms_cond_dim = adarms_cond_dim - - # Set default for adarms_cond_dim if use_adarms is True - if self.use_adarms and self.adarms_cond_dim is None: - self.adarms_cond_dim = self.hidden_size - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -__all__ = ["GemmaConfig"] \ No newline at end of file diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py deleted file mode 100644 index 
8377a5bf8562945fe0f3b2c3545a91c7d7ac9238..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py +++ /dev/null @@ -1,862 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, Optional, Union - -import torch -from torch import nn - -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...generation import GenerationMixin -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging -from .configuration_gemma import GemmaConfig - - -logger = logging.get_logger(__name__) - - -class GemmaRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None): - super().__init__() - self.eps = eps - self.dim = dim - self.cond_dim = cond_dim - - # Dense layer for adaptive normalization (if cond_dim is provided) - if cond_dim is not None: - #self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16) - self.dense = nn.Linear(cond_dim, dim * 3, bias=True) - # Initialize with zeros (matches source implementation) - nn.init.zeros_(self.dense.weight) - else: - self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16)) - self.dense = None - - def _norm(self, x): - # Compute variance in float32 (like the source implementation) - var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) - # Compute normalization in float32 - normed_inputs = x * torch.rsqrt(var + self.eps) - return normed_inputs - - def forward(self, x, cond=None): - dtype = x.dtype # original dtype, could be half-precision - normed_inputs = self._norm(x) - - if cond is None or self.dense is None: - # regular RMSNorm - # scale by learned parameter in float32 (matches source implementation) - normed_inputs = normed_inputs * (1.0 + self.weight.float()) - return normed_inputs.to(dtype), None # return in original dtype with None gate - - # adaptive RMSNorm (if cond is provided and dense layer 
exists) - if cond.shape[-1] != self.cond_dim: - raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}") - - #self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32) - modulation = self.dense(cond) - # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features] - if len(x.shape) == 3: # [batch, seq, features] - modulation = modulation.unsqueeze(1) - - scale, shift, gate = torch.chunk(modulation, 3, dim=-1) - - # Apply adaptive normalization: use model weight dtype to ensure compatibility - # model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16) - # scale = scale.to(model_dtype) - # shift = shift.to(model_dtype) - # gate = gate.to(model_dtype) - # normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype - - normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32) - - return normed_inputs.to(dtype), gate.to(dtype) - - def extra_repr(self): - repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}" - if self.dense is not None: - repr_str += f", adaptive=True, cond_dim={self.cond_dim}" - return repr_str - - -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, config: GemmaConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def _gated_residual(x, y, gate): - """ - Applies gated residual connection with optional gate parameter. 
- - Args: - x: Input tensor (residual) - y: Output tensor to be added - gate: Optional gate tensor to modulate the addition - - Returns: - x + y if gate is None, otherwise x + y * gate - """ - if x is None and y is None: - return None - if x is None or y is None: - return x if x is not None else y - if gate is None: - return x + y - return x + y * gate - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class GemmaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - use_cache: bool = False, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - # Use cache if provided - if past_key_value is not None: - if use_cache: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - else: - key_states = 
torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2) - value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class GemmaDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) - - self.mlp = GemmaMLP(config) - cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - adarms_cond: Optional[torch.Tensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = _gated_residual(residual, hidden_states, gate) - - # Fully Connected - residual = hidden_states - hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond) - hidden_states = self.mlp(hidden_states) - hidden_states = _gated_residual(residual, hidden_states, gate) - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -@auto_docstring -class GemmaPreTrainedModel(PreTrainedModel): - config_class = GemmaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["GemmaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_3 = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif 
isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GemmaRMSNorm): - if hasattr(module, 'weight'): - module.weight.data.fill_(1.0) - - -@auto_docstring -class GemmaModel(GemmaPreTrainedModel): - def __init__(self, config: GemmaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - - cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - self.rotary_emb = GemmaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - adarms_cond: Optional[torch.Tensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - """ - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
- ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - # embed positions - hidden_states = inputs_embeds - # Convert to bfloat16 if the first layer uses bfloat16 - if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.bfloat16) - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # normalized - # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - #hidden_states = hidden_states * normalizer - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - adarms_cond=adarms_cond, - **kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states, _ = self.norm(hidden_states, adarms_cond) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
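The adaptive branch of `GemmaRMSNorm` above replaces the usual learned scale with a condition-dependent scale/shift/gate triple, and the gate is later consumed by `_gated_residual` inside the decoder layer. The following is a minimal sketch of that modulation under assumed toy dimensions (`dim`, `cond_dim`, batch and sequence sizes below are illustrative, not the model's):

```python
import torch
import torch.nn as nn

dim, cond_dim, bsz, seq = 8, 8, 2, 5
dense = nn.Linear(cond_dim, dim * 3)      # adaRMS projection; its weight is zero-initialised in the model
nn.init.zeros_(dense.weight)

x = torch.randn(bsz, seq, dim)            # hidden states entering the norm
cond = torch.randn(bsz, cond_dim)         # e.g. the time embedding produced by the pi05 time MLP

# RMS-normalise in float32, as in GemmaRMSNorm._norm
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
normed = x * torch.rsqrt(var + 1e-6)

# Project the condition into per-channel scale, shift and gate terms
scale, shift, gate = torch.chunk(dense(cond).unsqueeze(1), 3, dim=-1)
normed = normed * (1 + scale) + shift     # adaptive normalisation replaces the usual (1 + weight) scale

# In GemmaDecoderLayer the gate then modulates the sub-layer output in the residual
# connection: _gated_residual(residual, sublayer_out, gate) = residual + sublayer_out * gate
sublayer_out = normed                     # stand-in for the attention/MLP output
out = x + sublayer_out * gate
print(out.shape)                          # torch.Size([2, 5, 8])
```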
- - -@auto_docstring -class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - adarms_cond: Optional[torch.Tensor] = None, - **kwargs: Unpack[KwargsForCausalLM], - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" 
- ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - adarms_cond=adarms_cond, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@auto_docstring( - custom_intro=""" - The Gemma Model transformer with a sequence classification head on top (linear layer). - - [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """ -) -class GemmaForSequenceClassification(GemmaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = GemmaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - adarms_cond: Optional[torch.Tensor] = None, - ) -> SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - - transformer_outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - adarms_cond=adarms_cond, - ) - hidden_states = transformer_outputs.last_hidden_state - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@auto_docstring -class GemmaForTokenClassification(GemmaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = GemmaModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - adarms_cond: Optional[torch.Tensor] = None, - ) -> TokenClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - - outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - adarms_cond=adarms_cond, - ) - sequence_output = outputs.last_hidden_state - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "GemmaModel", - "GemmaForCausalLM", - "GemmaForSequenceClassification", - "GemmaForTokenClassification", - "GemmaPreTrainedModel", -] diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py deleted file mode 100644 index a627b73246277095e3354b93158c98e1fa776897..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py +++ /dev/null @@ -1,622 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch PaliGemmamodel.""" - -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from ...cache_utils import Cache, HybridCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...processing_utils import Unpack -from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ..auto import AutoModel -from .configuration_paligemma import PaliGemmaConfig - - -logger = logging.get_logger(__name__) - - -@dataclass -@auto_docstring( - custom_intro=""" - Base class for Paligemma outputs, with hidden states and attentions. 
- """ -) -class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. - """ - - image_hidden_states: Optional[torch.FloatTensor] = None - - -@dataclass -@auto_docstring( - custom_intro=""" - Base class for PaliGemma causal language model (or autoregressive) outputs. - """ -) -class PaliGemmaCausalLMOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder after projecting last hidden state. 
- """ - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None - hidden_states: Optional[tuple[torch.FloatTensor]] = None - attentions: Optional[tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - - -class PaliGemmaMultiModalProjector(nn.Module): - def __init__(self, config: PaliGemmaConfig): - super().__init__() - self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True) - - def forward(self, image_features): - hidden_states = self.linear(image_features) - - return hidden_states - - -@auto_docstring -class PaliGemmaPreTrainedModel(PreTrainedModel): - config_class = PaliGemmaConfig - base_model_prefix = "" - supports_gradient_checkpointing = True - _no_split_modules = ["PaliGemmaMultiModalProjector"] - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True - - def _init_weights(self, module): - # important: this ported version of PaliGemmaisn't meant for training from scratch - only - # inference and fine-tuning - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - - -@auto_docstring( - custom_intro=""" - The Base Paligemma model which consists of a vision backbone and a language model withou language modeling head., - """ -) -class PaliGemmaModel(PaliGemmaPreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch - accepts_loss_kwargs = False - - def __init__(self, config: PaliGemmaConfig): - super().__init__(config) - self.vision_tower = AutoModel.from_config(config=config.vision_config) - self.multi_modal_projector = PaliGemmaMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - - language_model = AutoModel.from_config(config=config.text_config) - self.language_model = language_model - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self.post_init() - - # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - - def _update_causal_mask( - self, - attention_mask, - token_type_ids=None, - past_key_values=None, - cache_position=None, - input_tensor=None, - is_training: Optional[bool] = None, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - is_training = is_training if is_training is not None else self.training - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = 
torch.finfo(self.dtype).min - if input_tensor is None: - input_tensor = attention_mask - - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device - ) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # First unmask prefix tokens during training - if is_training: - if token_type_ids is None: - raise ValueError("Token type ids must be provided during training") - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def get_image_features(self, pixel_values: torch.FloatTensor): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) - The tensors corresponding to the input images. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
- """ - image_outputs = self.vision_tower(pixel_values) - selected_image_feature = image_outputs.last_hidden_state - image_features = self.multi_modal_projector(selected_image_feature) - return image_features - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, PaligemmaModelOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") - >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") - - >>> prompt = "Where is the cat standing?" 
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs,) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Where is the cat standing?\nsnow" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - is_training = token_type_ids is not None and labels is not None - - # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_id >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." 
- ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **kwargs, - ) - - return PaligemmaModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - -@auto_docstring( - custom_intro=""" - The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., - """ -) -class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "^language_model.model": "model.language_model", - "^vision_tower": "model.vision_tower", - "^multi_modal_projector": "model.multi_modal_projector", - "^language_model.lm_head": "lm_head", - } - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: PaliGemmaConfig): - super().__init__(config) - self.model = PaliGemmaModel(config) - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - def get_image_features(self, pixel_values): - return self.model.get_image_features(pixel_values) - - # Make modules available throught conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked 
language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") - >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") - - >>> prompt = "Where is the cat standing?" - >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs,) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Where is the cat standing?\nsnow" - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids=input_ids, - pixel_values=pixel_values, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs - ) - - return PaliGemmaCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - logits_to_keep=None, - labels=None, - **kwargs, - ): - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - 
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self.model._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py deleted file mode 100644 index 89cc2ad4359d5273f3631410cbebbe845100cce6..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py +++ /dev/null @@ -1,4 +0,0 @@ -import transformers - -def check_whether_transformers_replace_is_installed_correctly(): - return transformers.__version__ == "4.53.2" \ No newline at end of file diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py deleted file mode 100644 index 0bf8bec4a068c964dd038ce9060513f61a0b8ff4..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py +++ /dev/null @@ -1,1237 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Siglip model.""" - -import math -import warnings -from dataclasses import dataclass -from typing import Any, Callable, Optional, Union - -import numpy as np -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.nn.init import _calculate_fan_in_and_fan_out - -from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int -from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig - - -logger = logging.get_logger(__name__) - - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. 
- - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -@dataclass -@auto_docstring( - custom_intro=""" - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - """ -) -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip -class SiglipVisionModelOutput(ModelOutput): - r""" - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - """ - - image_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: Optional[torch.FloatTensor] = None - hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - attentions: Optional[tuple[torch.FloatTensor, ...]] = None - - -@dataclass -@auto_docstring( - custom_intro=""" - Base class for text model's outputs that also contains a pooling of the last hidden states. - """ -) -# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip -class SiglipTextModelOutput(ModelOutput): - r""" - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The text embeddings obtained by applying the projection layer to the pooler_output. - """ - - text_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: Optional[torch.FloatTensor] = None - hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - attentions: Optional[tuple[torch.FloatTensor, ...]] = None - - -@dataclass -@auto_docstring -# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip -class SiglipOutput(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. - logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. 
- logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. - text_model_output (`BaseModelOutputWithPooling`): - The output of the [`SiglipTextModel`]. - vision_model_output (`BaseModelOutputWithPooling`): - The output of the [`SiglipVisionModel`]. - """ - - loss: Optional[torch.FloatTensor] = None - logits_per_image: Optional[torch.FloatTensor] = None - logits_per_text: Optional[torch.FloatTensor] = None - text_embeds: Optional[torch.FloatTensor] = None - image_embeds: Optional[torch.FloatTensor] = None - text_model_output: BaseModelOutputWithPooling = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class SiglipVisionEmbeddings(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing and no class embeddings. 
- - Adapted from: - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - """ - - num_patches = embeddings.shape[1] - num_positions = self.position_embedding.weight.shape[0] - - # always interpolate when tracing to ensure the exported model works for dynamic input shapes - if not torch.jit.is_tracing() and num_patches == num_positions and height == width: - return self.position_embedding(self.position_ids) - - patch_pos_embed = self.position_embedding.weight.unsqueeze(0) - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size - - sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, - ) - - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return patch_pos_embed - - def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: - _, _, height, width = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip -class SiglipTextEmbeddings(nn.Module): - def __init__(self, config: SiglipTextConfig): - super().__init__() - embed_dim = config.hidden_size - - self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) - self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer( - "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - max_position_embedding = self.position_embedding.weight.shape[0] - - if seq_length > max_position_embedding: - raise ValueError( - f"Sequence length must be less than max_position_embeddings (got `sequence length`: " - f"{seq_length} and max_position_embeddings: {max_position_embedding}" - ) - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - attn_weights 
= torch.matmul(query, key.transpose(-1, -2)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class SiglipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - self.is_causal = False - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """Input shape: Batch x Time x Channel""" - - batch_size, seq_length, embed_dim = hidden_states.shape - - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) - - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - queries, - keys, - values, - attention_mask, - is_causal=self.is_causal, - scaling=self.scale, - dropout=0.0 if not self.training else self.dropout, - ) - - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip -class SiglipMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class SiglipEncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): - super().__init__() - self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.self_attn = SiglipAttention(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - ) -> tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -@auto_docstring -class SiglipPreTrainedModel(PreTrainedModel): - config_class = SiglipConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - _no_split_modules = [ - "SiglipTextEmbeddings", - "SiglipEncoderLayer", - "SiglipVisionEmbeddings", - "SiglipEncoderLayer", - "SiglipMultiheadAttentionPoolingHead", - ] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, SiglipVisionEmbeddings): - width = ( - self.config.vision_config.hidden_size - if isinstance(self.config, SiglipConfig) - else self.config.hidden_size - ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, SiglipAttention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, SiglipMLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, SiglipMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) - elif isinstance(module, SiglipModel): - logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.data.fill_(logit_scale_init) - module.logit_bias.data.zero_() - elif isinstance(module, SiglipForImageClassification): - nn.init.normal_( - module.classifier.weight, - std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, - ) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip -class SiglipEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SiglipEncoderLayer`]. 
- - Args: - config: SiglipConfig - """ - - def __init__(self, config: SiglipConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - # Ignore copy - @can_return_tuple - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) -> BaseModelOutput: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - ) - - -class SiglipTextTransformer(nn.Module): - def __init__(self, config: SiglipTextConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - self.embeddings = SiglipTextEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - - self.head = nn.Linear(embed_dim, config.projection_size) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) -> BaseModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( 
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if input_ids is None: - raise ValueError("You have to specify input_ids") - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. - # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: - # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - - encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - last_hidden_state = encoder_outputs.last_hidden_state - last_hidden_state = self.final_layer_norm(last_hidden_state) - - # Assuming "sticky" EOS tokenization, last token is always EOS. - pooled_output = last_hidden_state[:, -1, :] - pooled_output = self.head(pooled_output) - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -@auto_docstring( - custom_intro=""" - The text model from SigLIP without any head or projection on top. - """ -) -class SiglipTextModel(SiglipPreTrainedModel): - config_class = SiglipTextConfig - - def __init__(self, config: SiglipTextConfig): - super().__init__(config) - self.text_model = SiglipTextTransformer(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, value): - self.text_model.embeddings.token_embedding = value - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) -> BaseModelOutputWithPooling: - r""" - Examples: - - ```python - >>> from transformers import AutoTokenizer, SiglipTextModel - - >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states - ```""" - - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - -class SiglipVisionTransformer(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SiglipVisionEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.use_head = True if not hasattr(config, "vision_use_head") else 
config.vision_use_head - if self.use_head: - self.head = SiglipMultiheadAttentionPoolingHead(config) - - @can_return_tuple - @auto_docstring - def forward( - self, - pixel_values, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = False, - ) -> BaseModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - # Convert to bfloat16 if the encoder uses bfloat16 - if len(self.encoder.layers) > 0 and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.bfloat16) - - encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - last_hidden_state = encoder_outputs.last_hidden_state - last_hidden_state = self.post_layernorm(last_hidden_state) - - pooler_output = self.head(last_hidden_state) if self.use_head else None - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooler_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class SiglipMultiheadAttentionPoolingHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - - self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - - def forward(self, hidden_state): - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) - - hidden_state = self.attention(probe, hidden_state, hidden_state)[0] - - residual = hidden_state - hidden_state = self.layernorm(hidden_state) - hidden_state = residual + self.mlp(hidden_state) - - return hidden_state[:, 0] - - -@auto_docstring( - custom_intro=""" - The vision model from SigLIP without any head or projection on top. 
- """ -) -class SiglipVisionModel(SiglipPreTrainedModel): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - - def __init__(self, config: SiglipVisionConfig): - super().__init__(config) - - self.vision_model = SiglipVisionTransformer(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding - - @can_return_tuple - @auto_docstring - def forward( - self, - pixel_values, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, - ) -> BaseModelOutputWithPooling: - r""" - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, SiglipVisionModel - - >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled features - ```""" - - return self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - -@auto_docstring -class SiglipModel(SiglipPreTrainedModel): - config_class = SiglipConfig - - def __init__(self, config: SiglipConfig): - super().__init__(config) - - if not isinstance(config.text_config, SiglipTextConfig): - raise TypeError( - "config.text_config is expected to be of type SiglipTextConfig but is of type" - f" {type(config.text_config)}." - ) - - if not isinstance(config.vision_config, SiglipVisionConfig): - raise TypeError( - "config.vision_config is expected to be of type SiglipVisionConfig but is of type" - f" {type(config.vision_config)}." - ) - - text_config = config.text_config - vision_config = config.vision_config - - # First, initialize the text and vision models with proper attention implementation - text_model = SiglipTextModel._from_config(text_config) - vision_model = SiglipVisionModel._from_config(vision_config) - - # Second, get the text and vision submodules (for backward compatibility) - self.text_model = text_model.text_model - self.vision_model = vision_model.vision_model - - self.logit_scale = nn.Parameter(torch.randn(1)) - self.logit_bias = nn.Parameter(torch.randn(1)) - - # Initialize weights and apply final processing - self.post_init() - - @auto_docstring - def get_text_features( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) -> torch.FloatTensor: - r""" - Returns: - text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by - applying the projection layer to the pooled output of [`SiglipTextModel`]. 
- - Examples: - - ```python - >>> from transformers import AutoTokenizer, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") - >>> with torch.no_grad(): - ... text_features = model.get_text_features(**inputs) - ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - pooled_output = text_outputs.pooler_output - - return pooled_output - - @auto_docstring - def get_image_features( - self, - pixel_values: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, - ) -> torch.FloatTensor: - r""" - Returns: - image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of [`SiglipVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> with torch.no_grad(): - ... image_features = model.get_image_features(**inputs) - ```""" - # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - pooled_output = vision_outputs.pooler_output - - return pooled_output - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - return_loss: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, - ) -> SiglipOutput: - r""" - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. 
- - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] - >>> # important: we pass `padding=max_length` since the model was trained with this - >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") - - >>> with torch.no_grad(): - ... outputs = model(**inputs) - - >>> logits_per_image = outputs.logits_per_image - >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities - >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") - 31.9% that image 0 is 'a photo of 2 cats' - ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - image_embeds = vision_outputs.pooler_output - text_embeds = text_outputs.pooler_output - - # normalized features - image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) - - # cosine similarity as logits - logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) - - logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) - logits_per_text = logits_per_text * logit_scale.exp() + logit_bias - - logits_per_image = logits_per_text.t() - - loss = None - if return_loss: - # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 - eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) - m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye - loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) - nll = -torch.sum(loglik, dim=-1) - loss = nll.mean() - - return SiglipOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -@auto_docstring( - custom_intro=""" - SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of - the patch tokens) e.g. for ImageNet. 
- """ -) -class SiglipForImageClassification(SiglipPreTrainedModel): - main_input_name = "pixel_values" - - def __init__(self, config: SiglipConfig) -> None: - super().__init__(config) - - self.num_labels = config.num_labels - - # Create the vision model with proper attention - # and take only vision_model submodule (for backward compatibility) - vision_model = SiglipVisionModel._from_config(config.vision_config) - self.vision_model = vision_model.vision_model - - # Classifier head - self.classifier = ( - nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() - ) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, - ) -> ImageClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, SiglipForImageClassification - >>> import torch - >>> from PIL import Image - >>> import requests - - >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> # note: we are loading a `SiglipModel` from the hub here, - >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
- >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") - >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") - - >>> inputs = image_processor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the two classes - >>> predicted_class_idx = logits.argmax(-1).item() - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) - Predicted class: LABEL_1 - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - sequence_output = outputs.last_hidden_state - - # average pool the patch tokens - sequence_output = torch.mean(sequence_output, dim=1) - # apply classifier - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "SiglipModel", - "SiglipPreTrainedModel", - "SiglipTextModel", - "SiglipVisionModel", - "SiglipForImageClassification", -] \ No newline at end of file diff --git a/capvector-pi05/src/openpi/policies/aloha_policy.py b/capvector-pi05/src/openpi/policies/aloha_policy.py deleted file mode 100644 index b006f736096b6d8301262be477ac6b2329707337..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/policies/aloha_policy.py +++ /dev/null @@ -1,202 +0,0 @@ -import dataclasses -from typing import ClassVar - -import einops -import numpy as np - -from openpi import transforms - - -def make_aloha_example() -> dict: - """Creates a random input example for the Aloha policy.""" - return { - "state": np.ones((14,)), - "images": { - "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), - }, - "prompt": "do something", - } - - -@dataclasses.dataclass(frozen=True) -class AlohaInputs(transforms.DataTransformFn): - """Inputs for the Aloha policy. 
- - Expected inputs: - - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS. - - state: [14] - - actions: [action_horizon, 14] - """ - - # If true, this will convert the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi: bool = True - - # The expected cameras names. All input cameras must be in this set. Missing cameras will be - # replaced with black images and the corresponding `image_mask` will be set to False. - EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist") - - def __call__(self, data: dict) -> dict: - data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi) - - in_images = data["images"] - if set(in_images) - set(self.EXPECTED_CAMERAS): - raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}") - - # Assume that base image always exists. - base_image = in_images["cam_high"] - - images = { - "base_0_rgb": base_image, - } - image_masks = { - "base_0_rgb": np.True_, - } - - # Add the extra images. - extra_image_names = { - "left_wrist_0_rgb": "cam_left_wrist", - "right_wrist_0_rgb": "cam_right_wrist", - } - for dest, source in extra_image_names.items(): - if source in in_images: - images[dest] = in_images[source] - image_masks[dest] = np.True_ - else: - images[dest] = np.zeros_like(base_image) - image_masks[dest] = np.False_ - - inputs = { - "image": images, - "image_mask": image_masks, - "state": data["state"], - } - - # Actions are only available during training. - if "actions" in data: - actions = np.asarray(data["actions"]) - actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi) - inputs["actions"] = actions - - if "prompt" in data: - inputs["prompt"] = data["prompt"] - - return inputs - - -@dataclasses.dataclass(frozen=True) -class AlohaOutputs(transforms.DataTransformFn): - """Outputs for the Aloha policy.""" - - # If true, this will convert the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi: bool = True - - def __call__(self, data: dict) -> dict: - # Only return the first 14 dims. - actions = np.asarray(data["actions"][:, :14]) - return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)} - - -def _joint_flip_mask() -> np.ndarray: - """Used to convert between aloha and pi joint angles.""" - return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1]) - - -def _normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) - - -def _unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val - - -def _gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with pi0 which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = _unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return np.arcsin(np.clip(value, -1.0, 1.0)) - - # The constants are taken from the Interbotix code. 
- value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110). - # There are 4096 total encoder counts and aloha uses a zero of 2048. - # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296) - return _normalize(value, min_val=0.5476, max_val=1.6296) - - -def _gripper_from_angular(value): - # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # We do not scale the output since the trossen model predictions are already in radians. - # See the comment in _gripper_to_angular for a derivation of the constant - value = value + 0.5476 - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return _normalize(value, min_val=-0.6213, max_val=1.4910) - - -def _gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. - value = _unnormalize(value, min_val=-0.6213, max_val=1.4910) - return value - 0.5476 - - -def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict: - # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper] - # dim sizes: [6, 1, 6, 1] - state = np.asarray(data["state"]) - state = _decode_state(state, adapt_to_pi=adapt_to_pi) - - def convert_image(img): - img = np.asarray(img) - # Convert to uint8 if using float images. - if np.issubdtype(img.dtype, np.floating): - img = (255 * img).astype(np.uint8) - # Convert from [channel, height, width] to [height, width, channel]. - return einops.rearrange(img, "c h w -> h w c") - - images = data["images"] - images_dict = {name: convert_image(img) for name, img in images.items()} - - data["images"] = images_dict - data["state"] = state - return data - - -def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray: - if adapt_to_pi: - # Flip the joints. - state = _joint_flip_mask() * state - # Reverse the gripper transformation that is being applied by the Aloha runtime. - state[[6, 13]] = _gripper_to_angular(state[[6, 13]]) - return state - - -def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray: - if adapt_to_pi: - # Flip the joints. 
- actions = _joint_flip_mask() * actions - actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]]) - return actions - - -def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray: - if adapt_to_pi: - actions = _joint_flip_mask() * actions - actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]]) - return actions diff --git a/capvector-pi05/src/openpi/policies/droid_policy.py b/capvector-pi05/src/openpi/policies/droid_policy.py deleted file mode 100644 index 786985d9268f1aed0a57bf2bb780ca6da692683f..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/policies/droid_policy.py +++ /dev/null @@ -1,81 +0,0 @@ -import dataclasses - -import einops -import numpy as np - -from openpi import transforms -from openpi.models import model as _model - - -def make_droid_example() -> dict: - """Creates a random input example for the Droid policy.""" - return { - "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "observation/joint_position": np.random.rand(7), - "observation/gripper_position": np.random.rand(1), - "prompt": "do something", - } - - -def _parse_image(image) -> np.ndarray: - image = np.asarray(image) - if np.issubdtype(image.dtype, np.floating): - image = (255 * image).astype(np.uint8) - if image.shape[0] == 3: - image = einops.rearrange(image, "c h w -> h w c") - return image - - -@dataclasses.dataclass(frozen=True) -class DroidInputs(transforms.DataTransformFn): - # Determines which model will be used. - model_type: _model.ModelType - - def __call__(self, data: dict) -> dict: - gripper_pos = np.asarray(data["observation/gripper_position"]) - if gripper_pos.ndim == 0: - # Ensure gripper position is a 1D array, not a scalar, so we can concatenate with joint positions - gripper_pos = gripper_pos[np.newaxis] - state = np.concatenate([data["observation/joint_position"], gripper_pos]) - - # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically - # stores as float32 (C,H,W), gets skipped for policy inference - base_image = _parse_image(data["observation/exterior_image_1_left"]) - wrist_image = _parse_image(data["observation/wrist_image_left"]) - - match self.model_type: - case _model.ModelType.PI0 | _model.ModelType.PI05: - names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") - images = (base_image, wrist_image, np.zeros_like(base_image)) - image_masks = (np.True_, np.True_, np.False_) - case _model.ModelType.PI0_FAST: - names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb") - # We don't mask out padding images for FAST models. - images = (base_image, np.zeros_like(base_image), wrist_image) - image_masks = (np.True_, np.True_, np.True_) - case _: - raise ValueError(f"Unsupported model type: {self.model_type}") - - inputs = { - "state": state, - "image": dict(zip(names, images, strict=True)), - "image_mask": dict(zip(names, image_masks, strict=True)), - } - - if "actions" in data: - inputs["actions"] = np.asarray(data["actions"]) - - if "prompt" in data: - if isinstance(data["prompt"], bytes): - data["prompt"] = data["prompt"].decode("utf-8") - inputs["prompt"] = data["prompt"] - - return inputs - - -@dataclasses.dataclass(frozen=True) -class DroidOutputs(transforms.DataTransformFn): - def __call__(self, data: dict) -> dict: - # Only return the first 8 dims. 
- return {"actions": np.asarray(data["actions"][:, :8])} diff --git a/capvector-pi05/src/openpi/policies/libero_policy.py b/capvector-pi05/src/openpi/policies/libero_policy.py deleted file mode 100644 index 7b51e93d201e73e155aba27db9c0fe531d93d074..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/policies/libero_policy.py +++ /dev/null @@ -1,100 +0,0 @@ -import dataclasses - -import einops -import numpy as np - -from openpi import transforms -from openpi.models import model as _model - - -def make_libero_example() -> dict: - """Creates a random input example for the Libero policy.""" - return { - "observation/state": np.random.rand(8), - "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), - "prompt": "do something", - } - - -def _parse_image(image) -> np.ndarray: - image = np.asarray(image) - if np.issubdtype(image.dtype, np.floating): - image = (255 * image).astype(np.uint8) - if image.shape[0] == 3: - image = einops.rearrange(image, "c h w -> h w c") - return image - - -@dataclasses.dataclass(frozen=True) -class LiberoInputs(transforms.DataTransformFn): - """ - This class is used to convert inputs to the model to the expected format. It is used for both training and inference. - - For your own dataset, you can copy this class and modify the keys based on the comments below to pipe - the correct elements of your dataset into the model. - """ - - # Determines which model will be used. - # Do not change this for your own dataset. - model_type: _model.ModelType - - def __call__(self, data: dict) -> dict: - # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically - # stores as float32 (C,H,W), gets skipped for policy inference. - # Keep this for your own dataset, but if your dataset stores the images - # in a different key than "observation/image" or "observation/wrist_image", - # you should change it below. - # Pi0 models support three image inputs at the moment: one third-person view, - # and two wrist views (left and right). If your dataset does not have a particular type - # of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the - # right wrist image below. - base_image = _parse_image(data["observation/image"]) - wrist_image = _parse_image(data["observation/wrist_image"]) - - # Create inputs dict. Do not change the keys in the dict below. - inputs = { - "state": data["observation/state"], - "image": { - "base_0_rgb": base_image, - "left_wrist_0_rgb": wrist_image, - # Pad any non-existent images with zero-arrays of the appropriate shape. - "right_wrist_0_rgb": np.zeros_like(base_image), - }, - "image_mask": { - "base_0_rgb": np.True_, - "left_wrist_0_rgb": np.True_, - # We only mask padding images for pi0 model, not pi0-FAST. Do not change this for your own dataset. - "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_, - }, - } - - # Pad actions to the model action dimension. Keep this for your own dataset. - # Actions are only available during training. - if "actions" in data: - inputs["actions"] = data["actions"] - - # Pass the prompt (aka language instruction) to the model. - # Keep this for your own dataset (but modify the key if the instruction is not - # stored in "prompt"; the output dict always needs to have the key "prompt"). 
- if "prompt" in data: - inputs["prompt"] = data["prompt"] - - return inputs - - -@dataclasses.dataclass(frozen=True) -class LiberoOutputs(transforms.DataTransformFn): - """ - This class is used to convert outputs from the model back the the dataset specific format. It is - used for inference only. - - For your own dataset, you can copy this class and modify the action dimension based on the comments below. - """ - - def __call__(self, data: dict) -> dict: - # Only return the first N actions -- since we padded actions above to fit the model action - # dimension, we need to now parse out the correct number of actions in the return dict. - # For Libero, we only return the first 7 actions (since the rest is padding). - # For your own dataset, replace `7` with the action dimension of your dataset. - return {"actions": np.asarray(data["actions"][:, :7])} diff --git a/capvector-pi05/src/openpi/policies/policy.py b/capvector-pi05/src/openpi/policies/policy.py deleted file mode 100644 index 334f50d68038e02780d0a937cfdae4a9826c0b7e..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/policies/policy.py +++ /dev/null @@ -1,135 +0,0 @@ -from collections.abc import Sequence -import logging -import pathlib -import time -from typing import Any, TypeAlias - -import flax -import flax.traverse_util -import jax -import jax.numpy as jnp -import numpy as np -from openpi_client import base_policy as _base_policy -import torch -from typing_extensions import override - -from openpi import transforms as _transforms -from openpi.models import model as _model -from openpi.shared import array_typing as at -from openpi.shared import nnx_utils - -BasePolicy: TypeAlias = _base_policy.BasePolicy - - -class Policy(BasePolicy): - def __init__( - self, - model: _model.BaseModel, - *, - rng: at.KeyArrayLike | None = None, - transforms: Sequence[_transforms.DataTransformFn] = (), - output_transforms: Sequence[_transforms.DataTransformFn] = (), - sample_kwargs: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - pytorch_device: str = "cpu", - is_pytorch: bool = False, - ): - """Initialize the Policy. - - Args: - model: The model to use for action sampling. - rng: Random number generator key for JAX models. Ignored for PyTorch models. - transforms: Input data transformations to apply before inference. - output_transforms: Output data transformations to apply after inference. - sample_kwargs: Additional keyword arguments to pass to model.sample_actions. - metadata: Additional metadata to store with the policy. - pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0"). - Only relevant when is_pytorch=True. - is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model. - """ - self._model = model - self._input_transform = _transforms.compose(transforms) - self._output_transform = _transforms.compose(output_transforms) - self._sample_kwargs = sample_kwargs or {} - self._metadata = metadata or {} - self._is_pytorch_model = is_pytorch - self._pytorch_device = pytorch_device - - if self._is_pytorch_model: - self._model = self._model.to(pytorch_device) - self._model.eval() - self._sample_actions = model.sample_actions - else: - # JAX model setup - self._sample_actions = nnx_utils.module_jit(model.sample_actions) - self._rng = rng or jax.random.key(0) - - @override - def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: ignore[misc] - # Make a copy since transformations may modify the inputs in place. 
- inputs = jax.tree.map(lambda x: x, obs) - inputs = self._input_transform(inputs) - if not self._is_pytorch_model: - # Make a batch and convert to jax.Array. - inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs) - self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng) - else: - # Convert inputs to PyTorch tensors and move to correct device - inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs) - sample_rng_or_pytorch_device = self._pytorch_device - - # Prepare kwargs for sample_actions - sample_kwargs = dict(self._sample_kwargs) - if noise is not None: - noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise) - - if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension - noise = noise[None, ...] # Make it (1, action_horizon, action_dim) - sample_kwargs["noise"] = noise - - observation = _model.Observation.from_dict(inputs) - start_time = time.monotonic() - outputs = { - "state": inputs["state"], - "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs), - } - model_time = time.monotonic() - start_time - if self._is_pytorch_model: - outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs) - else: - outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs) - - outputs = self._output_transform(outputs) - outputs["policy_timing"] = { - "infer_ms": model_time * 1000, - } - return outputs - - @property - def metadata(self) -> dict[str, Any]: - return self._metadata - - -class PolicyRecorder(_base_policy.BasePolicy): - """Records the policy's behavior to disk.""" - - def __init__(self, policy: _base_policy.BasePolicy, record_dir: str): - self._policy = policy - - logging.info(f"Dumping policy records to: {record_dir}") - self._record_dir = pathlib.Path(record_dir) - self._record_dir.mkdir(parents=True, exist_ok=True) - self._record_step = 0 - - @override - def infer(self, obs: dict) -> dict: # type: ignore[misc] - results = self._policy.infer(obs) - - data = {"inputs": obs, "outputs": results} - data = flax.traverse_util.flatten_dict(data, sep="/") - - output_path = self._record_dir / f"step_{self._record_step}" - self._record_step += 1 - - np.save(output_path, np.asarray(data)) - return results diff --git a/capvector-pi05/src/openpi/policies/policy_config.py b/capvector-pi05/src/openpi/policies/policy_config.py deleted file mode 100644 index 18bc2211348f24fc29df58872115aaf826636e1c..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/policies/policy_config.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -import os -import pathlib -from typing import Any - -import jax.numpy as jnp - -import openpi.models.model as _model -import openpi.policies.policy as _policy -import openpi.shared.download as download -from openpi.training import checkpoints as _checkpoints -from openpi.training import config as _config -import openpi.transforms as transforms - - -def create_trained_policy( - train_config: _config.TrainConfig, - checkpoint_dir: pathlib.Path | str, - *, - repack_transforms: transforms.Group | None = None, - sample_kwargs: dict[str, Any] | None = None, - default_prompt: str | None = None, - norm_stats: dict[str, transforms.NormStats] | None = None, - pytorch_device: str | None = None, -) -> _policy.Policy: - """Create a policy from a trained checkpoint. - - Args: - train_config: The training config to use to create the model. 
- checkpoint_dir: The directory to load the model from. - repack_transforms: Optional transforms that will be applied before any other transforms. - sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default - kwargs will be used. - default_prompt: The default prompt to use for the policy. Will inject the prompt into the input - data if it doesn't already exist. - norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded - from the checkpoint directory. - pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0"). - If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu". - - Note: - The function automatically detects whether the model is PyTorch-based by checking for the - presence of "model.safensors" in the checkpoint directory. - """ - repack_transforms = repack_transforms or transforms.Group() - checkpoint_dir = download.maybe_download(str(checkpoint_dir)) - - # Check if this is a PyTorch model by looking for model.safetensors - weight_path = os.path.join(checkpoint_dir, "model.safetensors") - is_pytorch = os.path.exists(weight_path) - - logging.info("Loading model...") - if is_pytorch: - model = train_config.model.load_pytorch(train_config, weight_path) - model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16") - else: - model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16)) - data_config = train_config.data.create(train_config.assets_dirs, train_config.model) - if norm_stats is None: - # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure - # that the policy is using the same normalization stats as the original training process. 
- if data_config.asset_id is None: - raise ValueError("Asset id is required to load norm stats.") - norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id) - - # Determine the device to use for PyTorch models - if is_pytorch and pytorch_device is None: - try: - import torch - - pytorch_device = "cuda" if torch.cuda.is_available() else "cpu" - except ImportError: - pytorch_device = "cpu" - - return _policy.Policy( - model, - transforms=[ - *repack_transforms.inputs, - transforms.InjectDefaultPrompt(default_prompt), - *data_config.data_transforms.inputs, - transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), - *data_config.model_transforms.inputs, - ], - output_transforms=[ - *data_config.model_transforms.outputs, - transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm), - *data_config.data_transforms.outputs, - *repack_transforms.outputs, - ], - sample_kwargs=sample_kwargs, - metadata=train_config.policy_metadata, - is_pytorch=is_pytorch, - pytorch_device=pytorch_device if is_pytorch else None, - ) diff --git a/capvector-pi05/src/openpi/policies/policy_test.py b/capvector-pi05/src/openpi/policies/policy_test.py deleted file mode 100644 index adae783af773deae088f7920078bb8bb598f9194..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/policies/policy_test.py +++ /dev/null @@ -1,34 +0,0 @@ -from openpi_client import action_chunk_broker -import pytest - -from openpi.policies import aloha_policy -from openpi.policies import policy_config as _policy_config -from openpi.training import config as _config - - -@pytest.mark.manual -def test_infer(): - config = _config.get_config("pi0_aloha_sim") - policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim") - - example = aloha_policy.make_aloha_example() - result = policy.infer(example) - - assert result["actions"].shape == (config.model.action_horizon, 14) - - -@pytest.mark.manual -def test_broker(): - config = _config.get_config("pi0_aloha_sim") - policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim") - - broker = action_chunk_broker.ActionChunkBroker( - policy, - # Only execute the first half of the chunk. - action_horizon=config.model.action_horizon // 2, - ) - - example = aloha_policy.make_aloha_example() - for _ in range(config.model.action_horizon): - outputs = broker.infer(example) - assert outputs["actions"].shape == (14,) diff --git a/capvector-pi05/src/openpi/py.typed b/capvector-pi05/src/openpi/py.typed deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/src/openpi/serving/websocket_policy_server.py b/capvector-pi05/src/openpi/serving/websocket_policy_server.py deleted file mode 100644 index 6e6916d18c52f2521400bc382b3052f0bab2ba6f..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/serving/websocket_policy_server.py +++ /dev/null @@ -1,90 +0,0 @@ -import asyncio -import http -import logging -import time -import traceback - -from openpi_client import base_policy as _base_policy -from openpi_client import msgpack_numpy -import websockets.asyncio.server as _server -import websockets.frames - -logger = logging.getLogger(__name__) - - -class WebsocketPolicyServer: - """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation. - - Currently only implements the `load` and `infer` methods. 
- """ - - def __init__( - self, - policy: _base_policy.BasePolicy, - host: str = "0.0.0.0", - port: int | None = None, - metadata: dict | None = None, - ) -> None: - self._policy = policy - self._host = host - self._port = port - self._metadata = metadata or {} - logging.getLogger("websockets.server").setLevel(logging.INFO) - - def serve_forever(self) -> None: - asyncio.run(self.run()) - - async def run(self): - async with _server.serve( - self._handler, - self._host, - self._port, - compression=None, - max_size=None, - process_request=_health_check, - ) as server: - await server.serve_forever() - - async def _handler(self, websocket: _server.ServerConnection): - logger.info(f"Connection from {websocket.remote_address} opened") - packer = msgpack_numpy.Packer() - - await websocket.send(packer.pack(self._metadata)) - - prev_total_time = None - while True: - try: - start_time = time.monotonic() - obs = msgpack_numpy.unpackb(await websocket.recv()) - - infer_time = time.monotonic() - action = self._policy.infer(obs) - infer_time = time.monotonic() - infer_time - - action["server_timing"] = { - "infer_ms": infer_time * 1000, - } - if prev_total_time is not None: - # We can only record the last total time since we also want to include the send time. - action["server_timing"]["prev_total_ms"] = prev_total_time * 1000 - - await websocket.send(packer.pack(action)) - prev_total_time = time.monotonic() - start_time - - except websockets.ConnectionClosed: - logger.info(f"Connection from {websocket.remote_address} closed") - break - except Exception: - await websocket.send(traceback.format_exc()) - await websocket.close( - code=websockets.frames.CloseCode.INTERNAL_ERROR, - reason="Internal server error. Traceback included in previous frame.", - ) - raise - - -def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None: - if request.path == "/healthz": - return connection.respond(http.HTTPStatus.OK, "OK\n") - # Continue with the normal request handling. - return None diff --git a/capvector-pi05/src/openpi/shared/__init__.py b/capvector-pi05/src/openpi/shared/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/src/openpi/shared/array_typing.py b/capvector-pi05/src/openpi/shared/array_typing.py deleted file mode 100644 index fed20bfa1c79d540815bf2c444e803ba34383c9f..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/array_typing.py +++ /dev/null @@ -1,89 +0,0 @@ -import contextlib -import functools as ft -import inspect -from typing import TypeAlias, TypeVar, cast - -import beartype -import jax -import jax._src.tree_util as private_tree_util -import jax.core -from jaxtyping import ArrayLike -from jaxtyping import Bool # noqa: F401 -from jaxtyping import DTypeLike # noqa: F401 -from jaxtyping import Float -from jaxtyping import Int # noqa: F401 -from jaxtyping import Key # noqa: F401 -from jaxtyping import Num # noqa: F401 -from jaxtyping import PyTree -from jaxtyping import Real # noqa: F401 -from jaxtyping import UInt8 # noqa: F401 -from jaxtyping import config -from jaxtyping import jaxtyped -import jaxtyping._decorator -import torch - -# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277. -# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`, -# `jax.Sharding`, or even ) due to JAX tracing operations. 
this patch skips typechecking when the stack trace -# contains `jax._src.tree_util`, which should only be the case during tree unflattening. -_original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations # noqa: SLF001 -# Redefine Array to include both JAX arrays and PyTorch tensors -Array = jax.Array | torch.Tensor - - -def _check_dataclass_annotations(self, typechecker): - if not any( - frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"} - for frame in inspect.stack() - ): - return _original_check_dataclass_annotations(self, typechecker) - return None - - -jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations # noqa: SLF001 - -KeyArrayLike: TypeAlias = jax.typing.ArrayLike -Params: TypeAlias = PyTree[Float[ArrayLike, "..."]] - -T = TypeVar("T") - - -# runtime type-checking decorator -def typecheck(t: T) -> T: - return cast(T, ft.partial(jaxtyped, typechecker=beartype.beartype)(t)) - - -@contextlib.contextmanager -def disable_typechecking(): - initial = config.jaxtyping_disable - config.update("jaxtyping_disable", True) # noqa: FBT003 - yield - config.update("jaxtyping_disable", initial) - - -def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False): - """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer - error message than if `jax.tree.map` is naively used on PyTrees with different structures. - """ - - if errors := list(private_tree_util.equality_errors(expected, got)): - raise ValueError( - "PyTrees have different structure:\n" - + ( - "\n".join( - f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n" - for path, thing1, thing2, explanation in errors - ) - ) - ) - - if check_shapes or check_dtypes: - - def check(kp, x, y): - if check_shapes and x.shape != y.shape: - raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}") - - if check_dtypes and x.dtype != y.dtype: - raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}") - - jax.tree_util.tree_map_with_path(check, expected, got) diff --git a/capvector-pi05/src/openpi/shared/download.py b/capvector-pi05/src/openpi/shared/download.py deleted file mode 100644 index 7dd5304cb78d912f1bb4d63ce55b7c96f4c555d6..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/download.py +++ /dev/null @@ -1,194 +0,0 @@ -import concurrent.futures -import datetime -import logging -import os -import pathlib -import re -import shutil -import stat -import time -import urllib.parse - -import filelock -import fsspec -import fsspec.generic -import tqdm_loggable.auto as tqdm - -# Environment variable to control cache directory path, ~/.cache/openpi will be used by default. -_OPENPI_DATA_HOME = "OPENPI_DATA_HOME" -DEFAULT_CACHE_DIR = "~/.cache/openpi" - -logger = logging.getLogger(__name__) - - -def get_cache_dir() -> pathlib.Path: - cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve() - cache_dir.mkdir(parents=True, exist_ok=True) - _set_folder_permission(cache_dir) - return cache_dir - - -def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path: - """Download a file or directory from a remote filesystem to the local cache, and return the local path. 
- - If the local file already exists, it will be returned directly. - - It is safe to call this function concurrently from multiple processes. - See `get_cache_dir` for more details on the cache directory. - - Args: - url: URL to the file to download. - force_download: If True, the file will be downloaded even if it already exists in the cache. - **kwargs: Additional arguments to pass to fsspec. - - Returns: - Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute. - """ - # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem. - parsed = urllib.parse.urlparse(url) - - # Short circuit if this is a local path. - if parsed.scheme == "": - path = pathlib.Path(url) - if not path.exists(): - raise FileNotFoundError(f"File not found at {url}") - return path.resolve() - - cache_dir = get_cache_dir() - - local_path = cache_dir / parsed.netloc / parsed.path.strip("/") - local_path = local_path.resolve() - - # Check if the cache should be invalidated. - invalidate_cache = False - if local_path.exists(): - if force_download or _should_invalidate_cache(cache_dir, local_path): - invalidate_cache = True - else: - return local_path - - try: - lock_path = local_path.with_suffix(".lock") - with filelock.FileLock(lock_path): - # Ensure consistent permissions for the lock file. - _ensure_permissions(lock_path) - # First, remove the existing cache if it is expired. - if invalidate_cache: - logger.info(f"Removing expired cached entry: {local_path}") - if local_path.is_dir(): - shutil.rmtree(local_path) - else: - local_path.unlink() - - # Download the data to a local cache. - logger.info(f"Downloading {url} to {local_path}") - scratch_path = local_path.with_suffix(".partial") - _download_fsspec(url, scratch_path, **kwargs) - - shutil.move(scratch_path, local_path) - _ensure_permissions(local_path) - - except PermissionError as e: - msg = ( - f"Local file permission error was encountered while downloading {url}. " - f"Please try again after removing the cached data using: `rm -rf {local_path}*`" - ) - raise PermissionError(msg) from e - - return local_path - - -def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None: - """Download a file from a remote filesystem to the local cache, and return the local path.""" - fs, _ = fsspec.core.url_to_fs(url, **kwargs) - info = fs.info(url) - # Folders are represented by 0-byte objects with a trailing forward slash. 
- if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))): - total_size = fs.du(url) - else: - total_size = info["size"] - with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar: - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - future = executor.submit(fs.get, url, local_path, recursive=is_dir) - while not future.done(): - current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file()) - pbar.update(current_size - pbar.n) - time.sleep(1) - pbar.update(total_size - pbar.n) - - -def _set_permission(path: pathlib.Path, target_permission: int): - """chmod requires executable permission to be set, so we skip if the permission is already match with the target.""" - if path.stat().st_mode & target_permission == target_permission: - logger.debug(f"Skipping {path} because it already has correct permissions") - return - path.chmod(target_permission) - logger.debug(f"Set {path} to {target_permission}") - - -def _set_folder_permission(folder_path: pathlib.Path) -> None: - """Set folder permission to be read, write and searchable.""" - _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) - - -def _ensure_permissions(path: pathlib.Path) -> None: - """Since we are sharing cache directory with containerized runtime as well as training script, we need to - ensure that the cache directory has the correct permissions. - """ - - def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None: - cache_dir = get_cache_dir() - relative_path = path.relative_to(cache_dir) - moving_path = cache_dir - for part in relative_path.parts: - _set_folder_permission(moving_path / part) - moving_path = moving_path / part - - def _set_file_permission(file_path: pathlib.Path) -> None: - """Set all files to be read & writable, if it is a script, keep it as a script.""" - file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH - if file_path.stat().st_mode & 0o100: - _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) - else: - _set_permission(file_path, file_rw) - - _setup_folder_permission_between_cache_dir_and_path(path) - for root, dirs, files in os.walk(str(path)): - root_path = pathlib.Path(root) - for file in files: - file_path = root_path / file - _set_file_permission(file_path) - - for dir in dirs: - dir_path = root_path / dir - _set_folder_permission(dir_path) - - -def _get_mtime(year: int, month: int, day: int) -> float: - """Get the mtime of a given date at midnight UTC.""" - date = datetime.datetime(year, month, day, tzinfo=datetime.UTC) - return time.mktime(date.timetuple()) - - -# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format). -# Partial matching will be used from top to bottom and the first match will be chosen. -# Cached entries will be retained only if they are newer than the expiration timestamp. -_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = { - re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17), - re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6), - re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3), -} - - -def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool: - """Invalidate the cache if it is expired. 
Return True if the cache was invalidated.""" - - assert local_path.exists(), f"File not found at {local_path}" - - relative_path = str(local_path.relative_to(cache_dir)) - for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items(): - if pattern.match(relative_path): - # Remove if not newer than the expiration timestamp. - return local_path.stat().st_mtime <= expire_time - - return False diff --git a/capvector-pi05/src/openpi/shared/download_test.py b/capvector-pi05/src/openpi/shared/download_test.py deleted file mode 100644 index 48417fe3ad386a642e03bdb4a16b884ec951e1eb..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/download_test.py +++ /dev/null @@ -1,54 +0,0 @@ -import pathlib - -import pytest - -import openpi.shared.download as download - - -@pytest.fixture(scope="session", autouse=True) -def set_openpi_data_home(tmp_path_factory): - temp_dir = tmp_path_factory.mktemp("openpi_data") - with pytest.MonkeyPatch().context() as mp: - mp.setenv("OPENPI_DATA_HOME", str(temp_dir)) - yield - - -def test_download_local(tmp_path: pathlib.Path): - local_path = tmp_path / "local" - local_path.touch() - - result = download.maybe_download(str(local_path)) - assert result == local_path - - with pytest.raises(FileNotFoundError): - download.maybe_download("bogus") - - -def test_download_gs_dir(): - remote_path = "gs://openpi-assets/testdata/random" - - local_path = download.maybe_download(remote_path) - assert local_path.exists() - - new_local_path = download.maybe_download(remote_path) - assert new_local_path == local_path - - -def test_download_gs(): - remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin" - - local_path = download.maybe_download(remote_path) - assert local_path.exists() - - new_local_path = download.maybe_download(remote_path) - assert new_local_path == local_path - - -def test_download_fsspec(): - remote_path = "gs://big_vision/paligemma_tokenizer.model" - - local_path = download.maybe_download(remote_path, gs={"token": "anon"}) - assert local_path.exists() - - new_local_path = download.maybe_download(remote_path, gs={"token": "anon"}) - assert new_local_path == local_path diff --git a/capvector-pi05/src/openpi/shared/image_tools.py b/capvector-pi05/src/openpi/shared/image_tools.py deleted file mode 100644 index 50548c1b1176616dc45f257f4d97cf744ea7fa98..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/image_tools.py +++ /dev/null @@ -1,186 +0,0 @@ -import functools - -import jax -import jax.numpy as jnp -import torch -import torch.nn.functional as F # noqa: N812 - -import openpi.shared.array_typing as at - - -@functools.partial(jax.jit, static_argnums=(1, 2, 3)) -@at.typecheck -def resize_with_pad( - images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"], - height: int, - width: int, - method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR, -) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]: - """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion - by padding with black. If the image is float32, it must be in the range [-1, 1]. 
- """ - has_batch_dim = images.ndim == 4 - if not has_batch_dim: - images = images[None] # type: ignore - cur_height, cur_width = images.shape[1:3] - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - resized_images = jax.image.resize( - images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method - ) - if images.dtype == jnp.uint8: - # round from float back to uint8 - resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8) - elif images.dtype == jnp.float32: - resized_images = resized_images.clip(-1.0, 1.0) - else: - raise ValueError(f"Unsupported image dtype: {images.dtype}") - - pad_h0, remainder_h = divmod(height - resized_height, 2) - pad_h1 = pad_h0 + remainder_h - pad_w0, remainder_w = divmod(width - resized_width, 2) - pad_w1 = pad_w0 + remainder_w - padded_images = jnp.pad( - resized_images, - ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)), - constant_values=0 if images.dtype == jnp.uint8 else -1.0, - ) - - if not has_batch_dim: - padded_images = padded_images[0] - return padded_images - - -def resize_with_pad_torch( - images: torch.Tensor, - height: int, - width: int, - mode: str = "bilinear", -) -> torch.Tensor: - """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion - by padding with black. If the image is float32, it must be in the range [-1, 1]. - - Args: - images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] - height: Target height - width: Target width - mode: Interpolation mode ('bilinear', 'nearest', etc.) - - Returns: - Resized and padded tensor with same shape format as input - """ - # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] - if images.shape[-1] <= 4: # Assume channels-last format - channels_last = True - # Convert to channels-first for torch operations - if images.dim() == 3: - images = images.unsqueeze(0) # Add batch dimension - images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] - else: - channels_last = False - if images.dim() == 3: - images = images.unsqueeze(0) # Add batch dimension - - batch_size, channels, cur_height, cur_width = images.shape - - # Calculate resize ratio - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - - # Resize - resized_images = F.interpolate( - images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None - ) - - # Handle dtype-specific clipping - if images.dtype == torch.uint8: - resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) - elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) - else: - raise ValueError(f"Unsupported image dtype: {images.dtype}") - - # Calculate padding - pad_h0, remainder_h = divmod(height - resized_height, 2) - pad_h1 = pad_h0 + remainder_h - pad_w0, remainder_w = divmod(width - resized_width, 2) - pad_w1 = pad_w0 + remainder_w - - # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 - padded_images = F.pad( - resized_images, - (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom - mode="constant", - value=constant_value, - ) - - # Convert back to original format if needed - if channels_last: - padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - if batch_size == 1 and images.shape[0] == 1: - padded_images = 
padded_images.squeeze(0) # Remove batch dimension if it was added - - return padded_images - - -def replace_padding_0to1_torch(image: torch.Tensor,) -> torch.Tensor: - """PyTorch version of replace_padding_0to1. - OpenPI requires images with 0 value paddings, while VGGT series requires 1 value paddings. - Here it achieves this bounding-box based padding replacement. - Args: - image: Tensor of shape [*b, h, w, c] - Returns: - Padding-replaced tensor with same shape as input - """ - single = False - if image.dim() == 3: - image = image.unsqueeze(0) - single = True - - b, h, w, c = image.shape - device = image.device - - nonzero_any = (image != 0).any(dim=-1) - - row_any = nonzero_any.any(dim=2) - col_any = nonzero_any.any(dim=1) - - top = row_any.to(torch.float32).argmax(dim=1) - bottom = h - 1 - row_any.flip(dims=[1]).to(torch.float32).argmax(dim=1) - left = col_any.to(torch.float32).argmax(dim=1) - right = w - 1 - col_any.flip(dims=[1]).to(torch.float32).argmax(dim=1) - - has_any = row_any.any(dim=1) - top = torch.where(has_any, top, torch.zeros_like(top)) - bottom = torch.where(has_any, bottom, torch.full_like(bottom, h - 1)) - left = torch.where(has_any, left, torch.zeros_like(left)) - right = torch.where(has_any, right, torch.full_like(right, w - 1)) - - rows = torch.arange(h, device=device).view(1, h, 1) - cols = torch.arange(w, device=device).view(1, 1, w) - top_v = top.view(b, 1, 1) - bottom_v = bottom.view(b, 1, 1) - left_v = left.view(b, 1, 1) - right_v = right.view(b, 1, 1) - - row_mask = (rows >= top_v) & (rows <= bottom_v) - col_mask = (cols >= left_v) & (cols <= right_v) - inside_mask = row_mask & col_mask - - padding_mask = ~inside_mask - - pixel_zero = (image == 0).all(dim=-1) - - final_mask = padding_mask & pixel_zero - - if final_mask.any(): - mask_exp = final_mask.unsqueeze(-1).expand_as(image) - one_t = torch.tensor(1, dtype=image.dtype, device=device) - image = torch.where(mask_exp, one_t, image) - - if single: - image = image.squeeze(0) - return image \ No newline at end of file diff --git a/capvector-pi05/src/openpi/shared/image_tools_test.py b/capvector-pi05/src/openpi/shared/image_tools_test.py deleted file mode 100644 index be1b7e1bcd08106d61b04bd40e73e5eeecb9f484..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/image_tools_test.py +++ /dev/null @@ -1,37 +0,0 @@ -import jax.numpy as jnp - -from openpi.shared import image_tools - - -def test_resize_with_pad_shapes(): - # Test case 1: Resize image with larger dimensions - images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8) # Input images of shape (batch_size, height, width, channels) - height = 20 - width = 20 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (2, height, width, 3) - assert jnp.all(resized_images == 0) - - # Test case 2: Resize image with smaller dimensions - images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8) - height = 15 - width = 15 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (3, height, width, 3) - assert jnp.all(resized_images == 0) - - # Test case 3: Resize image with the same dimensions - images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8) - height = 50 - width = 50 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (1, height, width, 3) - assert jnp.all(resized_images == 0) - - # Test case 3: Resize image with odd-numbered padding - images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8) - height = 60 - 
width = 80 - resized_images = image_tools.resize_with_pad(images, height, width) - assert resized_images.shape == (1, height, width, 3) - assert jnp.all(resized_images == 0) diff --git a/capvector-pi05/src/openpi/shared/nnx_utils.py b/capvector-pi05/src/openpi/shared/nnx_utils.py deleted file mode 100644 index 08907a48a6d96bdaf5ce42f1da836c74bf0285ed..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/nnx_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -from collections.abc import Callable -import dataclasses -import functools -import inspect -import re -from typing import Any, ParamSpec, TypeVar - -import flax.nnx as nnx -import jax - -P = ParamSpec("P") -R = TypeVar("R") - - -def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]: - """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process. - - Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much - more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module - mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must - traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details. - - `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by - `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was - when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded - after the method call completes. - """ - if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)): - raise ValueError("module_jit must only be used on bound methods of nnx.Modules.") - - graphdef, state = nnx.split(meth.__self__) - - def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R: - module = nnx.merge(graphdef, state) - return meth.__func__(module, *args, **kwargs) - - jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs) - - @functools.wraps(meth) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return jitted_fn(state, *args, **kwargs) - - return wrapper - - -@dataclasses.dataclass(frozen=True) -class PathRegex: - """NNX Filter that matches paths using a regex. - - By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument. 
- """ - - pattern: str | re.Pattern - sep: str = "/" - - def __post_init__(self): - if not isinstance(self.pattern, re.Pattern): - object.__setattr__(self, "pattern", re.compile(self.pattern)) - - def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool: - joined_path = self.sep.join(str(x) for x in path) - assert isinstance(self.pattern, re.Pattern) - return self.pattern.fullmatch(joined_path) is not None - - -def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State: - """Apply a function to the leaves of the state that match the filter.""" - filtered_keys = set(state.filter(filter).flat_state()) - return state.map(lambda k, v: fn(v) if k in filtered_keys else v) diff --git a/capvector-pi05/src/openpi/shared/normalize.py b/capvector-pi05/src/openpi/shared/normalize.py deleted file mode 100644 index f4bf6100fb506a21c064cae725eb04fea0e00017..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/normalize.py +++ /dev/null @@ -1,146 +0,0 @@ -import json -import pathlib - -import numpy as np -import numpydantic -import pydantic - - -@pydantic.dataclasses.dataclass -class NormStats: - mean: numpydantic.NDArray - std: numpydantic.NDArray - q01: numpydantic.NDArray | None = None # 1st quantile - q99: numpydantic.NDArray | None = None # 99th quantile - - -class RunningStats: - """Compute running statistics of a batch of vectors.""" - - def __init__(self): - self._count = 0 - self._mean = None - self._mean_of_squares = None - self._min = None - self._max = None - self._histograms = None - self._bin_edges = None - self._num_quantile_bins = 5000 # for computing quantiles on the fly - - def update(self, batch: np.ndarray) -> None: - """ - Update the running statistics with a batch of vectors. - - Args: - vectors (np.ndarray): An array where all dimensions except the last are batch dimensions. - """ - batch = batch.reshape(-1, batch.shape[-1]) - num_elements, vector_length = batch.shape - if self._count == 0: - self._mean = np.mean(batch, axis=0) - self._mean_of_squares = np.mean(batch**2, axis=0) - self._min = np.min(batch, axis=0) - self._max = np.max(batch, axis=0) - self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] - self._bin_edges = [ - np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) - for i in range(vector_length) - ] - else: - if vector_length != self._mean.size: - raise ValueError("The length of new vectors does not match the initialized vector length.") - new_max = np.max(batch, axis=0) - new_min = np.min(batch, axis=0) - max_changed = np.any(new_max > self._max) - min_changed = np.any(new_min < self._min) - self._max = np.maximum(self._max, new_max) - self._min = np.minimum(self._min, new_min) - - if max_changed or min_changed: - self._adjust_histograms() - - self._count += num_elements - - batch_mean = np.mean(batch, axis=0) - batch_mean_of_squares = np.mean(batch**2, axis=0) - - # Update running mean and mean of squares. - self._mean += (batch_mean - self._mean) * (num_elements / self._count) - self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count) - - self._update_histograms(batch) - - def get_statistics(self) -> NormStats: - """ - Compute and return the statistics of the vectors processed so far. - - Returns: - dict: A dictionary containing the computed statistics. 
- """ - if self._count < 2: - raise ValueError("Cannot compute statistics for less than 2 vectors.") - - variance = self._mean_of_squares - self._mean**2 - stddev = np.sqrt(np.maximum(0, variance)) - q01, q99 = self._compute_quantiles([0.01, 0.99]) - return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99) - - def _adjust_histograms(self): - """Adjust histograms when min or max changes.""" - for i in range(len(self._histograms)): - old_edges = self._bin_edges[i] - new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1) - - # Redistribute the existing histogram counts to the new bins - new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i]) - - self._histograms[i] = new_hist - self._bin_edges[i] = new_edges - - def _update_histograms(self, batch: np.ndarray) -> None: - """Update histograms with new vectors.""" - for i in range(batch.shape[1]): - hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) - self._histograms[i] += hist - - def _compute_quantiles(self, quantiles): - """Compute quantiles based on histograms.""" - results = [] - for q in quantiles: - target_count = q * self._count - q_values = [] - for hist, edges in zip(self._histograms, self._bin_edges, strict=True): - cumsum = np.cumsum(hist) - idx = np.searchsorted(cumsum, target_count) - q_values.append(edges[idx]) - results.append(np.array(q_values)) - return results - - -class _NormStatsDict(pydantic.BaseModel): - norm_stats: dict[str, NormStats] - - -def serialize_json(norm_stats: dict[str, NormStats]) -> str: - """Serialize the running statistics to a JSON string.""" - return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2) - - -def deserialize_json(data: str) -> dict[str, NormStats]: - """Deserialize the running statistics from a JSON string.""" - return _NormStatsDict(**json.loads(data)).norm_stats - - -def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None: - """Save the normalization stats to a directory.""" - path = pathlib.Path(directory) / "norm_stats.json" - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(serialize_json(norm_stats)) - - -def load(directory: pathlib.Path | str) -> dict[str, NormStats]: - """Load the normalization stats from a directory.""" - path = pathlib.Path(directory) / "norm_stats.json" - if not path.exists(): - raise FileNotFoundError(f"Norm stats file not found at: {path}") - return deserialize_json(path.read_text()) diff --git a/capvector-pi05/src/openpi/shared/normalize_test.py b/capvector-pi05/src/openpi/shared/normalize_test.py deleted file mode 100644 index ab0aa15f9f3f0fd3b8aac6fea2c864d32be00c9b..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/shared/normalize_test.py +++ /dev/null @@ -1,43 +0,0 @@ -import numpy as np - -import openpi.shared.normalize as normalize - - -def test_normalize_update(): - arr = np.arange(12).reshape(4, 3) # 4 vectors of length 3 - - stats = normalize.RunningStats() - for i in range(len(arr)): - stats.update(arr[i : i + 1]) # Update with one vector at a time - results = stats.get_statistics() - - assert np.allclose(results.mean, np.mean(arr, axis=0)) - assert np.allclose(results.std, np.std(arr, axis=0)) - - -def test_serialize_deserialize(): - stats = normalize.RunningStats() - stats.update(np.arange(12).reshape(4, 3)) # 4 vectors of length 3 - - norm_stats = {"test": stats.get_statistics()} - norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats)) - assert 
np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean) - assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std) - - -def test_multiple_batch_dimensions(): - # Test with multiple batch dimensions: (2, 3, 4) where 4 is vector dimension - batch_shape = (2, 3, 4) - arr = np.random.rand(*batch_shape) - - stats = normalize.RunningStats() - stats.update(arr) # Should handle (2, 3, 4) -> reshape to (6, 4) - results = stats.get_statistics() - - # Flatten batch dimensions and compute expected stats - flattened = arr.reshape(-1, arr.shape[-1]) # (6, 4) - expected_mean = np.mean(flattened, axis=0) - expected_std = np.std(flattened, axis=0) - - assert np.allclose(results.mean, expected_mean) - assert np.allclose(results.std, expected_std) diff --git a/capvector-pi05/src/openpi/training/checkpoints.py b/capvector-pi05/src/openpi/training/checkpoints.py deleted file mode 100644 index 0f53c71d7e35b72e7d07c957fdb8010f09e23396..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/checkpoints.py +++ /dev/null @@ -1,159 +0,0 @@ -from __future__ import annotations - -import asyncio -import concurrent.futures as futures -import dataclasses -import logging -from typing import Protocol - -from etils import epath -import jax -import orbax.checkpoint as ocp -import orbax.checkpoint.future as future - -from openpi.shared import array_typing as at -import openpi.shared.normalize as _normalize -import openpi.training.data_loader as _data_loader -import openpi.training.utils as training_utils - - -def initialize_checkpoint_dir( - checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool -) -> tuple[ocp.CheckpointManager, bool]: - checkpoint_dir = epath.Path(checkpoint_dir).resolve() - resuming = False - if checkpoint_dir.exists(): - if overwrite: - checkpoint_dir.rmtree() - checkpoint_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Wiped checkpoint directory {checkpoint_dir}") - elif resume: - resuming = True - else: - raise FileExistsError( - f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume " - "to indicate how to handle it." - ) - - checkpoint_dir.mkdir(parents=True, exist_ok=True) - - mngr = ocp.CheckpointManager( - checkpoint_dir, - item_handlers={ - "assets": CallbackHandler(), - "train_state": ocp.PyTreeCheckpointHandler(), - "params": ocp.PyTreeCheckpointHandler(), - }, - options=ocp.CheckpointManagerOptions( - max_to_keep=1, - keep_period=keep_period, - create=False, - async_options=ocp.AsyncOptions(timeout_secs=7200), - ), - ) - - # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did - # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a - # checkpoint, since it will fail. - if resuming and tuple(mngr.all_steps()) in [(), (0,)]: - logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.") - resuming = False - - return mngr, resuming - - -def save_state( - checkpoint_manager: ocp.CheckpointManager, - state: training_utils.TrainState, - data_loader: _data_loader.DataLoader, - step: int, -): - def save_assets(directory: epath.Path): - # Save the normalization stats. 
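A hedged sketch of how the `RunningStats`/`NormStats` helpers shown above fit together with `normalize.save` and `normalize.load`; the asset directory name is hypothetical:

```python
import numpy as np
import openpi.shared.normalize as normalize

# Accumulate running statistics over batches of 7-dim state vectors.
stats = normalize.RunningStats()
for _ in range(4):
    stats.update(np.random.rand(32, 7))

# Persist them as <dir>/norm_stats.json and read them back.
norm_stats = {"state": stats.get_statistics()}
normalize.save("./assets/my_dataset", norm_stats)
restored = normalize.load("./assets/my_dataset")
assert np.asarray(restored["state"].mean).shape == (7,)
```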
- data_config = data_loader.data_config() - norm_stats = data_config.norm_stats - if norm_stats is not None and data_config.asset_id is not None: - _normalize.save(directory / data_config.asset_id, norm_stats) - - # Split params that can be used for inference into a separate item. - with at.disable_typechecking(): - train_state, params = _split_params(state) - items = { - "assets": save_assets, - "train_state": train_state, - "params": {"params": params}, - } - checkpoint_manager.save(step, items) - - -def restore_state( - checkpoint_manager: ocp.CheckpointManager, - state: training_utils.TrainState, - data_loader: _data_loader.DataLoader, - step: int | None = None, -) -> training_utils.TrainState: - del data_loader - - with at.disable_typechecking(): - # Split params that can be used for inference into a separate item. - train_state, params = _split_params(state) - restored = checkpoint_manager.restore( - step, - items={ - "train_state": train_state, - "params": {"params": params}, - }, - ) - return _merge_params(restored["train_state"], restored["params"]) - - -def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None: - norm_stats_dir = epath.Path(assets_dir) / asset_id - norm_stats = _normalize.load(norm_stats_dir) - logging.info(f"Loaded norm stats from {norm_stats_dir}") - return norm_stats - - -class Callback(Protocol): - def __call__(self, directory: epath.Path) -> None: ... - - -class CallbackHandler(ocp.AsyncCheckpointHandler): - """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring.""" - - def save(self, directory: epath.Path, args: CallbackSave): - if jax.process_index() == 0: - args.callback(directory) - - async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]: - return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))] - - def restore(self, *args, **kwargs): - raise NotImplementedError("CallbackHandler does not support restore") - - -@ocp.args.register_with_handler(CallbackHandler, for_save=True) -@dataclasses.dataclass -class CallbackSave(ocp.args.CheckpointArgs): - callback: Callback - - -@ocp.args.register_with_handler(CallbackHandler, for_restore=True) -class CallbackRestore(ocp.args.CheckpointArgs): ... - - -def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]: - if state.ema_params is not None: - params = state.ema_params - train_state = dataclasses.replace(state, ema_params=None) - else: - params = state.params - train_state = dataclasses.replace(state, params={}) - return train_state, params - - -def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState: - # Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split. 
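As a hedged illustration of the save/restore flow above (the checkpoint directory and the train-loop variables are placeholders):

```python
import openpi.training.checkpoints as _checkpoints

# When EMA is enabled, _split_params exports state.ema_params as the standalone "params" item,
# so the saved inference params are the EMA weights; _merge_params reverses this on restore.
mngr, resuming = _checkpoints.initialize_checkpoint_dir(
    "./checkpoints/debug/debug",  # hypothetical experiment directory
    keep_period=5000,
    overwrite=True,
    resume=False,
)
# Inside the training loop (state, data_loader, and step come from the trainer):
#     _checkpoints.save_state(mngr, state, data_loader, step)
# And to pick a run back up:
#     state = _checkpoints.restore_state(mngr, state, data_loader)
```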
- if train_state.params: - return dataclasses.replace(train_state, ema_params=params["params"]) - return dataclasses.replace(train_state, params=params["params"]) diff --git a/capvector-pi05/src/openpi/training/config.py b/capvector-pi05/src/openpi/training/config.py deleted file mode 100644 index 51690fa0a2385514eddde0a5716fcda15e5dcc2b..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/config.py +++ /dev/null @@ -1,1033 +0,0 @@ -"""See _CONFIGS for the list of available configs.""" - -import abc -from collections.abc import Sequence -import dataclasses -import difflib -import logging -import pathlib -from typing import Any, Literal, Protocol, TypeAlias - -import etils.epath as epath -import flax.nnx as nnx -from typing_extensions import override -import tyro - -import openpi.models.model as _model -import openpi.models.pi0_config as pi0_config -import openpi.models.pi0_fast as pi0_fast -import openpi.models.tokenizer as _tokenizer -import openpi.policies.aloha_policy as aloha_policy -import openpi.policies.droid_policy as droid_policy -import openpi.policies.libero_policy as libero_policy -import openpi.shared.download as _download -import openpi.shared.normalize as _normalize -import openpi.training.droid_rlds_dataset as droid_rlds_dataset -import openpi.training.misc.roboarena_config as roboarena_config -import openpi.training.optimizer as _optimizer -import openpi.training.weight_loaders as weight_loaders -import openpi.transforms as _transforms - -ModelType: TypeAlias = _model.ModelType -# Work around a tyro issue with using nnx.filterlib.Filter directly. -Filter: TypeAlias = nnx.filterlib.Filter - - -@dataclasses.dataclass(frozen=True) -class AssetsConfig: - """Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline. - - These assets will be replicated inside the checkpoint under the `assets/asset_id` directory. - - This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other - centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint - during fine-tuning, use: - - ``` - AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", - asset_id="trossen", - ) - ``` - """ - - # Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from - # a different checkpoint (e.g., base model checkpoint) or some other centralized location. - assets_dir: str | None = None - - # Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe - # different robot platforms. - asset_id: str | None = None - - -@dataclasses.dataclass(frozen=True) -class DataConfig: - # LeRobot repo id. If None, fake data will be created. - repo_id: str | None = None - # Directory within the assets directory containing the data assets. - asset_id: str | None = None - # Contains precomputed normalization stats. If None, normalization will not be performed. - norm_stats: dict[str, _transforms.NormStats] | None = None - - # Used to adopt the inputs from a dataset specific format to a common format - # which is expected by the data transforms. - repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) - # Data transforms, typically include robot specific transformations. Will be applied - # before the data is normalized. See `model.Observation` and `model.Actions` to learn about the - # normalized data. 
- data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) - # Model specific transforms. Will be applied after the data is normalized. - model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) - # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. - use_quantile_norm: bool = False - - # Names of keys that will be used by the data loader to generate the action sequence. The length of the - # sequence is defined by the `action_horizon` field in the model config. This should be adjusted if your - # LeRobot dataset is using different keys to represent the action. - action_sequence_keys: Sequence[str] = ("actions",) - - # If true, will use the LeRobot dataset task to define the prompt. - prompt_from_task: bool = False - - # Only used for RLDS data loader (ie currently only used for DROID). - rlds_data_dir: str | None = None - # Action space for DROID dataset. - action_space: droid_rlds_dataset.DroidActionSpace | None = None - # Path to the data filter file for DROID dataset - filter_dict_path: str | None = None - - -class GroupFactory(Protocol): - def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group: - """Create a group.""" - - -@dataclasses.dataclass(frozen=True) -class ModelTransformFactory(GroupFactory): - """Creates model transforms for standard pi0 models.""" - - # If provided, will determine the default prompt that be used by the model. - default_prompt: str | None = None - - def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group: - match model_config.model_type: - case _model.ModelType.PI0: - return _transforms.Group( - inputs=[ - _transforms.InjectDefaultPrompt(self.default_prompt), - _transforms.ResizeImages(224, 224), - _transforms.TokenizePrompt( - _tokenizer.PaligemmaTokenizer(model_config.max_token_len), - ), - _transforms.PadStatesAndActions(model_config.action_dim), - ], - ) - case _model.ModelType.PI05: - assert isinstance(model_config, pi0_config.Pi0Config) - return _transforms.Group( - inputs=[ - _transforms.InjectDefaultPrompt(self.default_prompt), - _transforms.ResizeImages(224, 224), - _transforms.TokenizePrompt( - _tokenizer.PaligemmaTokenizer(model_config.max_token_len), - discrete_state_input=model_config.discrete_state_input, - ), - _transforms.PadStatesAndActions(model_config.action_dim), - ], - ) - case _model.ModelType.PI0_FAST: - tokenizer_cls = ( - _tokenizer.FASTTokenizer - if model_config.fast_model_tokenizer is None - else model_config.fast_model_tokenizer - ) - tokenizer_kwargs = ( - {} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs - ) - return _transforms.Group( - inputs=[ - _transforms.InjectDefaultPrompt(self.default_prompt), - _transforms.ResizeImages(224, 224), - _transforms.TokenizeFASTInputs( - tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs), - ), - ], - outputs=[ - _transforms.ExtractFASTActions( - tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs), - action_horizon=model_config.action_horizon, - action_dim=model_config.action_dim, - ) - ], - ) - - -@dataclasses.dataclass(frozen=True) -class DataConfigFactory(abc.ABC): - # The LeRobot repo id. - repo_id: str = tyro.MISSING - # Determines how the assets will be loaded. - assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig) - # Base config that will be updated by the factory. 
- base_config: tyro.conf.Suppress[DataConfig | None] = None - - @abc.abstractmethod - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - """Create a data config.""" - - def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None - asset_id = self.assets.asset_id or repo_id - return dataclasses.replace( - self.base_config or DataConfig(), - repo_id=repo_id, - asset_id=asset_id, - norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id), - use_quantile_norm=model_config.model_type != ModelType.PI0, - ) - - def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None: - if asset_id is None: - return None - try: - data_assets_dir = str(assets_dir / asset_id) - norm_stats = _normalize.load(_download.maybe_download(data_assets_dir)) - logging.info(f"Loaded norm stats from {data_assets_dir}") - return norm_stats - except FileNotFoundError: - logging.info(f"Norm stats not found in {data_assets_dir}, skipping.") - return None - - -@dataclasses.dataclass(frozen=True) -class FakeDataConfig(DataConfigFactory): - repo_id: str = "fake" - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - return DataConfig(repo_id=self.repo_id) - - -@dataclasses.dataclass(frozen=True) -class SimpleDataConfig(DataConfigFactory): - # Factory for the data transforms. - data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory) - # Factory for the model transforms. - model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory) - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - return dataclasses.replace( - self.create_base_config(assets_dirs, model_config), - data_transforms=self.data_transforms(model_config), - model_transforms=self.model_transforms(model_config), - ) - - -@dataclasses.dataclass(frozen=True) -class LeRobotAlohaDataConfig(DataConfigFactory): - # If true, will convert joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions: bool = True - # If provided, will be injected into the input data if the "prompt" key is not present. - default_prompt: str | None = None - # If true, this will convert the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. People who - # use standard Aloha data should set this to true. - adapt_to_pi: bool = True - - # Repack transforms. - repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field( - default=_transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "images": {"cam_high": "observation.images.top"}, - "state": "observation.state", - "actions": "action", - } - ) - ] - ) - ) - # Action keys that will be used to read the action sequence from the dataset. 
- action_sequence_keys: Sequence[str] = ("action",) - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - data_transforms = _transforms.Group( - inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)], - outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)], - ) - if self.use_delta_joint_actions: - delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1) - data_transforms = data_transforms.push( - inputs=[_transforms.DeltaActions(delta_action_mask)], - outputs=[_transforms.AbsoluteActions(delta_action_mask)], - ) - - model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config) - - return dataclasses.replace( - self.create_base_config(assets_dirs, model_config), - repack_transforms=self.repack_transforms, - data_transforms=data_transforms, - model_transforms=model_transforms, - action_sequence_keys=self.action_sequence_keys, - ) - - -@dataclasses.dataclass(frozen=True) -class LeRobotLiberoDataConfig(DataConfigFactory): - """ - This config is used to configure transforms that are applied at various parts of the data pipeline. - For your own dataset, you can copy this class and modify the transforms to match your dataset based on the - comments below. - """ - - extra_delta_transform: bool = False - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - # The repack transform is *only* applied to the data coming from the dataset, - # and *not* during inference. We can use it to make inputs from the dataset look - # as close as possible to those coming from the inference environment (e.g. match the keys). - # Below, we match the keys in the dataset (which we defined in the data conversion script) to - # the keys we use in our inference pipeline (defined in the inference script for libero). - # For your own dataset, first figure out what keys your environment passes to the policy server - # and then modify the mappings below so your dataset's keys get matched to those target keys. - # The repack transform simply remaps key names here. - repack_transform = _transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "observation/image": "image", - "observation/wrist_image": "wrist_image", - "observation/state": "state", - "actions": "actions", - "prompt": "prompt", - } - ) - ] - ) - - # The data transforms are applied to the data coming from the dataset *and* during inference. - # Below, we define the transforms for data going into the model (``inputs``) and the transforms - # for data coming out of the model (``outputs``) (the latter is only used during inference). - # We defined these transforms in `libero_policy.py`. You can check the detailed comments there for - # how to modify the transforms to match your dataset. Once you created your own transforms, you can - # replace the transforms below with your own. - data_transforms = _transforms.Group( - inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)], - outputs=[libero_policy.LiberoOutputs()], - ) - - # One additional data transform: pi0 models are trained on delta actions (relative to the first - # state in each action chunk). IF your data has ``absolute`` actions (e.g. target joint angles) - # you can uncomment the following line to convert the actions to delta actions. The only exception - # is for the gripper actions which are always absolute. 
- # In the example below, we would apply the delta conversion to the first 6 actions (joints) and - # leave the 7th action (gripper) unchanged, i.e. absolute. - # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to - # apply a separate delta conversion (that's why it's commented out). Choose whether to apply this - # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box. - - # LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this - # extra delta transform. - if self.extra_delta_transform: - delta_action_mask = _transforms.make_bool_mask(6, -1) - data_transforms = data_transforms.push( - inputs=[_transforms.DeltaActions(delta_action_mask)], - outputs=[_transforms.AbsoluteActions(delta_action_mask)], - ) - - # Model transforms include things like tokenizing the prompt and action targets - # You do not need to change anything here for your own dataset. - model_transforms = ModelTransformFactory()(model_config) - - # We return all data transforms for training and inference. No need to change anything here. - return dataclasses.replace( - self.create_base_config(assets_dirs, model_config), - repack_transforms=repack_transform, - data_transforms=data_transforms, - model_transforms=model_transforms, - ) - - -@dataclasses.dataclass(frozen=True) -class RLDSDroidDataConfig(DataConfigFactory): - """ - Config for training on DROID, using RLDS data format (for efficient training on larger datasets). - """ - - rlds_data_dir: str | None = None - action_space: droid_rlds_dataset.DroidActionSpace | None = None - - # Filtering options. Can pass a path to a dictionary that maps episodes to timestep ranges - # to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with - # f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata. - # Path to the filter dictionary file. - filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json" - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - repack_transform = _transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "observation/exterior_image_1_left": "observation/image", - "observation/wrist_image_left": "observation/wrist_image", - "observation/joint_position": "observation/joint_position", - "observation/gripper_position": "observation/gripper_position", - "actions": "actions", - "prompt": "prompt", - } - ) - ] - ) - - data_transforms = _transforms.Group( - inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)], - outputs=[droid_policy.DroidOutputs()], - ) - - if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION: - # Data loader returns absolute joint position actions -- convert to delta actions for training. - delta_action_mask = _transforms.make_bool_mask(7, -1) - data_transforms = data_transforms.push( - inputs=[_transforms.DeltaActions(delta_action_mask)], - outputs=[_transforms.AbsoluteActions(delta_action_mask)], - ) - - model_transforms = ModelTransformFactory()(model_config) - - assert self.rlds_data_dir is not None, "Need to set rlds data dir for RLDS data loader." 
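A small sketch of the delta/absolute action convention used by the data configs above; the mask semantics are inferred from the surrounding comments:

```python
import openpi.transforms as _transforms

# Per the comments above, make_bool_mask(6, -1) marks the first six action dims (joints)
# for delta conversion and leaves the last dim (gripper) absolute.
delta_action_mask = _transforms.make_bool_mask(6, -1)

# Training side: convert absolute dataset actions to deltas w.r.t. the current state.
# Inference side: convert the model's delta outputs back to absolute actions.
delta_inputs = [_transforms.DeltaActions(delta_action_mask)]
absolute_outputs = [_transforms.AbsoluteActions(delta_action_mask)]
```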
- - return dataclasses.replace( - self.create_base_config(assets_dirs, model_config), - repack_transforms=repack_transform, - data_transforms=data_transforms, - model_transforms=model_transforms, - rlds_data_dir=self.rlds_data_dir, - action_space=self.action_space, - filter_dict_path=self.filter_dict_path, - ) - - -@dataclasses.dataclass(frozen=True) -class LeRobotDROIDDataConfig(DataConfigFactory): - """ - Example data config for custom DROID dataset in LeRobot format. - To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py - """ - - @override - def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: - repack_transform = _transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "observation/exterior_image_1_left": "exterior_image_1_left", - "observation/exterior_image_2_left": "exterior_image_2_left", - "observation/wrist_image_left": "wrist_image_left", - "observation/joint_position": "joint_position", - "observation/gripper_position": "gripper_position", - "actions": "actions", - "prompt": "prompt", - } - ) - ] - ) - # We assume joint *velocity* actions, so we should *not* apply an additional delta transform. - data_transforms = _transforms.Group( - inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)], - outputs=[droid_policy.DroidOutputs()], - ) - model_transforms = ModelTransformFactory()(model_config) - - return dataclasses.replace( - self.create_base_config(assets_dirs, model_config), - repack_transforms=repack_transform, - data_transforms=data_transforms, - model_transforms=model_transforms, - ) - - -@dataclasses.dataclass(frozen=True) -class TrainConfig: - # Name of the config. Must be unique. Will be used to reference this config. - name: tyro.conf.Suppress[str] - # Project name. - project_name: str = "openpi" - # Experiment name. Will be used to name the metadata and checkpoint directories. - exp_name: str = tyro.MISSING - - # Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models - # -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may - # define additional attributes. - model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config) - - # A weight loader can optionally load (possibly partial) weights from disk after the model is initialized. - weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader) - - # Optional path to a PyTorch checkpoint to load weights from. - pytorch_weight_path: str | None = None - - # Spatial Forcing configs - vggt_weight_path: str | None = None - vggt_dim: int = 1024 - - vla_layers_align: int | None = None # total 18 for paligemma-2b - vggt_layers_align: int | None = None # total 24 for VGGT - - pooling_func: str | None = None - use_vggt_pe: bool | None = None - use_vlm_norm: bool | None = None - - align_loss_coeff: float = 0.0 - - # CapVector configs - regularization_vector_path: str | None = None - regularization_coeff: float = 0.0 - - # Precision for PyTorch training. 
- pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16" - - lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule) - optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW) - ema_decay: float | None = 0.99 - - # Specifies which weights should be frozen. - freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing) - - # Determines the data to be trained on. - data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig) - - # Base directory for config assets (e.g., norm stats). - assets_base_dir: str = "./assets" - # Base directory for checkpoints. - checkpoint_base_dir: str = "./checkpoints" - - # Random seed that will be used by random generators during training. - seed: int = 42 - # Global batch size. - batch_size: int = 32 - # Number of workers to use for the data loader. Increasing this number will speed up data loading but - # will increase memory and CPU usage. - num_workers: int = 2 - # Number of train steps (batches) to run. - num_train_steps: int = 30_000 - - # How often (in steps) to log training metrics. - log_interval: int = 100 - # How often (in steps) to save checkpoints. - save_interval: int = 1000 - # If set, any existing checkpoints matching step % keep_period == 0 will not be deleted. - keep_period: int | None = 5000 - - # If true, will overwrite the checkpoint directory if it already exists. - overwrite: bool = False - # If true, will resume training from the last checkpoint. - resume: bool = False - - # If true, will enable wandb logging. - wandb_enabled: bool = True - - # Used to pass metadata to the policy server. - policy_metadata: dict[str, Any] | None = None - - # If the value is greater than 1, FSDP will be enabled and shard across number of specified devices; overall - # device memory will be reduced but training could potentially be slower. - # eg. if total device is 4 and fsdp devices is 2; then the model will shard to 2 devices and run - # data parallel between 2 groups of devices. - fsdp_devices: int = 1 - - @property - def assets_dirs(self) -> pathlib.Path: - """Get the assets directory for this config.""" - return (pathlib.Path(self.assets_base_dir) / self.name).resolve() - - @property - def checkpoint_dir(self) -> pathlib.Path: - """Get the checkpoint directory for this config.""" - if not self.exp_name: - raise ValueError("--exp_name must be set") - return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve() - - @property - def trainable_filter(self) -> nnx.filterlib.Filter: - """Get the filter for the trainable parameters.""" - return nnx.All(nnx.Param, nnx.Not(self.freeze_filter)) - - def __post_init__(self) -> None: - if self.resume and self.overwrite: - raise ValueError("Cannot resume and overwrite at the same time.") - - -# Use `get_config` if you need to get a config by name in your code. -_CONFIGS = [ - # - # Inference Aloha configs. 
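As a hedged sketch of how the CapVector-related fields of `TrainConfig` above are meant to be filled in; the paths and names here are placeholders rather than shipped assets, and the names used are assumed to be in scope within this config module:

```python
# Hypothetical config: standard pi05 fine-tuning plus the CapVector orthogonal-regularization
# fields defined on TrainConfig above. Compare with the pi05_capvector_aloha_place_block entry below.
capvector_example = TrainConfig(
    name="pi05_capvector_example",
    exp_name="example_run",
    model=pi0_config.Pi0Config(pi05=True),
    data=FakeDataConfig(),  # stand-in; a real run would use a LeRobot*DataConfig
    regularization_vector_path="checkpoints/diff/example_capability_vector.pth",  # placeholder path
    regularization_coeff=1e-4,
    wandb_enabled=False,
)
# checkpoint_dir resolves under checkpoint_base_dir/name/exp_name.
print(capvector_example.checkpoint_dir)
```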
- # - TrainConfig( - name="pi0_aloha", - model=pi0_config.Pi0Config(), - data=LeRobotAlohaDataConfig( - assets=AssetsConfig(asset_id="trossen"), - ), - policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, - ), - TrainConfig( - name="pi05_aloha", - model=pi0_config.Pi0Config(pi05=True), - data=LeRobotAlohaDataConfig( - assets=AssetsConfig(asset_id="trossen"), - ), - policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, - ), - TrainConfig( - name="pi0_aloha_towel", - model=pi0_config.Pi0Config(), - data=LeRobotAlohaDataConfig( - assets=AssetsConfig(asset_id="trossen"), - default_prompt="fold the towel", - ), - policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, - ), - TrainConfig( - name="pi0_aloha_tupperware", - model=pi0_config.Pi0Config(), - data=LeRobotAlohaDataConfig( - assets=AssetsConfig(asset_id="trossen"), - default_prompt="open the tupperware and put the food on the plate", - ), - policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, - ), - # - # Inference DROID configs. - # - TrainConfig( - name="pi0_droid", - model=pi0_config.Pi0Config(action_horizon=10), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - TrainConfig( - name="pi0_fast_droid", - model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - TrainConfig( - name="pi05_droid", - model=pi0_config.Pi0Config(action_horizon=15, pi05=True), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - # - # Fine-tuning Libero configs. - # - # These train configs define the hyperparameters for fine-tuning the base model on your own dataset. - # They are used to define key elements like the dataset you are training on, the base checkpoint you - # are using, and other hyperparameters like how many training steps to run or what learning rate to use. - # For your own dataset, you can copy this class and modify the dataset name, and data transforms based on - # the comments below. - TrainConfig( - # Change the name to reflect your model and dataset. - name="pi0_libero", - # Here you define the model config -- In this example we use pi0 as the model - # architecture and perform *full* finetuning. in the examples below we show how to modify - # this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture. - model=pi0_config.Pi0Config(), - # Here you define the dataset you are training on. In this example we use the Libero - # dataset. For your own dataset, you can change the repo_id to point to your dataset. - # Also modify the DataConfig to use the new config you made for your dataset above. - data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", - base_config=DataConfig( - # This flag determines whether we load the prompt (i.e. the task instruction) from the - # ``task`` field in the LeRobot dataset. 
If set to True, the prompt will show up in - # a field called ``prompt`` in the input dict. The recommended setting is True. - prompt_from_task=True, - ), - extra_delta_transform=True, - ), - # Here you define which pre-trained checkpoint you want to load to initialize the model. - # This should match the model config you chose above -- i.e. in this case we use the pi0 base model. - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), - # Below you can define other hyperparameters like the learning rate, number of training steps, etc. - # Check the base TrainConfig class for a full list of available hyperparameters. - num_train_steps=30_000, - ), - TrainConfig( - name="pi0_libero_low_mem_finetune", - # Here is an example of loading a pi0 model for LoRA fine-tuning. - model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"), - data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), - extra_delta_transform=True, - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), - num_train_steps=30_000, - # The freeze filter defines which parameters should be frozen during training. - # We have a convenience function in the model config that returns the default freeze filter - # for the given model config for LoRA finetuning. Just make sure it matches the model config - # you chose above. - freeze_filter=pi0_config.Pi0Config( - paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora" - ).get_freeze_filter(), - # Turn off EMA for LoRA finetuning. - ema_decay=None, - ), - TrainConfig( - name="pi0_fast_libero", - # Here is an example of loading a pi0-FAST model for full finetuning. - # Modify action_dim and action_horizon to match your dataset (action horizon is equal to - # the desired action chunk length). - # The max_token_len is the maximum number of (non-image) tokens the model can handle. - # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens. - # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw - # a warning), while choosing it too large will waste memory (since we pad each batch element to the - # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for - # two-arm robots. Generally, err on the lower side here first, and potentially increase the value if - # you see many warnings being thrown during training. - model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180), - data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), - extra_delta_transform=True, - ), - # Note that we load the pi0-FAST base model checkpoint here. - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), - num_train_steps=30_000, - ), - TrainConfig( - name="pi0_fast_libero_low_mem_finetune", - # Here is an example of loading a pi0-FAST model for LoRA finetuning. - # For setting action_dim, action_horizon, and max_token_len, see the comments above. 
- model=pi0_fast.Pi0FASTConfig( - action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" - ), - data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), - extra_delta_transform=True, - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), - num_train_steps=30_000, - # Again, make sure to match the model config above when extracting the freeze filter - # that specifies which parameters should be frozen during LoRA finetuning. - freeze_filter=pi0_fast.Pi0FASTConfig( - action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" - ).get_freeze_filter(), - # Turn off EMA for LoRA finetuning. - ema_decay=None, - ), - TrainConfig( - name="pi05_libero", - model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), - data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), - extra_delta_transform=False, - ), - batch_size=256, - lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=10_000, - peak_lr=5e-5, - decay_steps=1_000_000, - decay_lr=5e-5, - ), - optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), - ema_decay=0.999, - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), - pytorch_weight_path="/path/to/your/pytorch_weight_path", - num_train_steps=30_000, - ), - # - # Fine-tuning Aloha configs. - # - # Personal Tasks - TrainConfig( - name="pi05_capvector_aloha_place_block", # - model=pi0_config.Pi0Config(pi05=True, discrete_state_input=False), - data=LeRobotAlohaDataConfig( - repo_id="cobot_dataset/place_one_floor_block", # your datasets repo_id, like "/" - default_prompt="place the green block", - repack_transforms=_transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "images": { - "cam_high": "observation.images.cam_high", - "cam_left_wrist": "observation.images.cam_left_wrist", - "cam_right_wrist": "observation.images.cam_right_wrist", - }, - "state": "observation.state", - "actions": "action", - } - ) - ] - ), - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), - pytorch_weight_path='./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial', - # CapVector - regularization_vector_path='checkpoints/diff/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial.pth', - regularization_coeff=1e-4, - # - num_train_steps=30_000, - batch_size=16, - ema_decay=None, - wandb_enabled=False, - ), - # - # This is a test config that is used to illustate how train on a custom LeRobot dataset. 
- # For instuctions on how to convert and train on your own Aloha dataset see examples/aloha_real/README.md - TrainConfig( - name="pi0_aloha_pen_uncap", - model=pi0_config.Pi0Config(), - data=LeRobotAlohaDataConfig( - repo_id="physical-intelligence/aloha_pen_uncap_diverse", - assets=AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", - asset_id="trossen", - ), - default_prompt="uncap the pen", - repack_transforms=_transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "images": { - "cam_high": "observation.images.cam_high", - "cam_left_wrist": "observation.images.cam_left_wrist", - "cam_right_wrist": "observation.images.cam_right_wrist", - }, - "state": "observation.state", - "actions": "action", - } - ) - ] - ), - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), - num_train_steps=20_000, - ), - TrainConfig( - name="pi05_aloha_pen_uncap", - model=pi0_config.Pi0Config(pi05=True), - data=LeRobotAlohaDataConfig( - repo_id="physical-intelligence/aloha_pen_uncap_diverse", - assets=AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets", - asset_id="trossen", - ), - default_prompt="uncap the pen", - repack_transforms=_transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "images": { - "cam_high": "observation.images.cam_high", - "cam_left_wrist": "observation.images.cam_left_wrist", - "cam_right_wrist": "observation.images.cam_right_wrist", - }, - "state": "observation.state", - "actions": "action", - } - ) - ] - ), - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), - num_train_steps=20_000, - batch_size=64, - ), - # - # Fine-tuning DROID configs. - # - TrainConfig( - # This config is for fine-tuning pi0-FAST-base on the *full* DROID dataset. - # We use RLDS data loading to make training on this large dataset tractable. - # For fine-tuning on your own DROID dataset, see below. - name="pi0_fast_full_droid_finetune", - model=pi0_fast.Pi0FASTConfig( - action_dim=8, - action_horizon=16, - max_token_len=180, - ), - data=RLDSDroidDataConfig( - repo_id="droid", - # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). - rlds_data_dir="", - action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), - lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=1_000, - peak_lr=5e-5, - decay_steps=1_000_000, - decay_lr=5e-5, - ), - num_train_steps=100_000, # 100k steps should be sufficient, takes ~2 days on 8x H100s - batch_size=256, - log_interval=100, - save_interval=5000, - keep_period=20_000, - num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally - ), - TrainConfig( - # This config is for fine-tuning pi05 on the *full* DROID dataset. - # We use RLDS data loading to make training on this large dataset tractable. - # For fine-tuning on your own DROID dataset, see below. - name="pi05_full_droid_finetune", - model=pi0_config.Pi0Config( - pi05=True, - action_dim=32, - action_horizon=16, - ), - data=RLDSDroidDataConfig( - repo_id="droid", - # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
- rlds_data_dir="/mnt/pi-data/kevin", - action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, - assets=AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/", - asset_id="droid", - ), - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), - lr_schedule=_optimizer.CosineDecaySchedule( - warmup_steps=1_000, - peak_lr=5e-5, - decay_steps=1_000_000, - decay_lr=5e-5, - ), - num_train_steps=100_000, - batch_size=256, - log_interval=100, - save_interval=5000, - keep_period=10_000, - num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally - ), - TrainConfig( - # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset. - # Here, we use LeRobot data format (like for all other fine-tuning examples) - # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py - name="pi05_droid_finetune", - model=pi0_config.Pi0Config( - pi05=True, - action_dim=32, # pi05 is trained with 32-dim actions - action_horizon=16, - ), - data=LeRobotDROIDDataConfig( - # Replace with your custom DROID LeRobot dataset repo id. - repo_id="your_hf_username/my_droid_dataset", - base_config=DataConfig(prompt_from_task=True), - assets=AssetsConfig( - # Important: reuse the original DROID norm stats during fine-tuning! - assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets", - asset_id="droid", - ), - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_droid/params"), - num_train_steps=20_000, - batch_size=32, - ), - # - # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment. - # - TrainConfig( - name="pi0_aloha_sim", - model=pi0_config.Pi0Config(), - data=LeRobotAlohaDataConfig( - repo_id="lerobot/aloha_sim_transfer_cube_human", - default_prompt="Transfer cube", - use_delta_joint_actions=False, - ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), - num_train_steps=20_000, - ), - # - # Debugging configs. - # - TrainConfig( - name="debug", - data=FakeDataConfig(), - batch_size=2, - model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"), - save_interval=100, - overwrite=True, - exp_name="debug", - num_train_steps=10, - wandb_enabled=False, - ), - TrainConfig( - name="debug_restore", - data=FakeDataConfig(), - batch_size=2, - model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"), - weight_loader=weight_loaders.CheckpointWeightLoader("./checkpoints/debug/debug/9/params"), - overwrite=True, - exp_name="debug", - num_train_steps=10, - wandb_enabled=False, - ), - TrainConfig( - name="debug_pi05", - model=pi0_config.Pi0Config(pi05=True, paligemma_variant="dummy", action_expert_variant="dummy"), - data=FakeDataConfig(), - batch_size=2, - num_train_steps=10, - overwrite=True, - exp_name="debug_pi05", - wandb_enabled=False, - ), - # - # RoboArena configs. 
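# A small usage sketch: how a named entry from _CONFIGS is typically looked up and overridden,
# using the `get_config` helper defined at the end of this file and the `dataclasses.replace`
# pattern used in data_loader_test.py further below. The name and batch size are just examples.
import dataclasses

from openpi.training import config as _config

cfg = _config.get_config("pi0_aloha_sim")      # unknown names raise with a closest-match hint
cfg = dataclasses.replace(cfg, batch_size=4)   # TrainConfig is a dataclass, so override fields via replace
print(cfg.name, cfg.batch_size)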
- # - *roboarena_config.get_roboarena_configs(), -] - -if len({config.name for config in _CONFIGS}) != len(_CONFIGS): - raise ValueError("Config names must be unique.") -_CONFIGS_DICT = {config.name: config for config in _CONFIGS} - - -def cli() -> TrainConfig: - return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()}) - - -def get_config(config_name: str) -> TrainConfig: - """Get a config by name.""" - if config_name not in _CONFIGS_DICT: - closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0) - closest_str = f" Did you mean '{closest[0]}'? " if closest else "" - raise ValueError(f"Config '{config_name}' not found.{closest_str}") - - return _CONFIGS_DICT[config_name] diff --git a/capvector-pi05/src/openpi/training/data_loader.py b/capvector-pi05/src/openpi/training/data_loader.py deleted file mode 100644 index 1847371abda186a0ea9de50e9f86b504e90a2cc2..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/data_loader.py +++ /dev/null @@ -1,540 +0,0 @@ -from collections.abc import Iterator, Sequence -import logging -import multiprocessing -import os -import typing -from typing import Literal, Protocol, SupportsIndex, TypeVar - -import jax -import jax.numpy as jnp -import lerobot.common.datasets.lerobot_dataset as lerobot_dataset -import numpy as np -import torch - -import openpi.models.model as _model -import openpi.training.config as _config -from openpi.training.droid_rlds_dataset import DroidRldsDataset -import openpi.transforms as _transforms - -T_co = TypeVar("T_co", covariant=True) - - -class Dataset(Protocol[T_co]): - """Interface for a dataset with random access.""" - - def __getitem__(self, index: SupportsIndex) -> T_co: - raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") - - def __len__(self) -> int: - raise NotImplementedError("Subclasses of Dataset should implement __len__.") - - -class IterableDataset(Protocol[T_co]): - """Interface for an iterable dataset.""" - - def __iter__(self) -> Iterator[T_co]: - raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.") - - def __len__(self) -> int: - raise NotImplementedError("Subclasses of Dataset should implement __len__.") - - -class DataLoader(Protocol[T_co]): - """Interface for a data loader.""" - - def data_config(self) -> _config.DataConfig: - """Get the data config for this data loader.""" - raise NotImplementedError("Subclasses of DataLoader should implement data_config.") - - def __iter__(self) -> Iterator[T_co]: - raise NotImplementedError("Subclasses of DataLoader should implement __iter__.") - - -class TransformedDataset(Dataset[T_co]): - def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]): - self._dataset = dataset - self._transform = _transforms.compose(transforms) - - def __getitem__(self, index: SupportsIndex) -> T_co: - return self._transform(self._dataset[index]) - - def __len__(self) -> int: - return len(self._dataset) - - -class IterableTransformedDataset(IterableDataset[T_co]): - def __init__( - self, - dataset: IterableDataset, - transforms: Sequence[_transforms.DataTransformFn], - *, - is_batched: bool = False, - ): - self._dataset = dataset - self._transform = _transforms.compose(transforms) - self._is_batched = is_batched - - def __iter__(self): - for sample in self._dataset: - if self._is_batched: - # Transforms are designed to be applied to individual samples. 
So we need to split the batch into - # individual samples and apply the transform to each sample individually. - batch_size = next(v.shape[0] for v in sample.values()) - - # Split batch into individual samples using tree_map - individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)] # noqa: B023 - - # Transform each sample - transformed = [self._transform(s) for s in individual_samples] - - # Recombine batch with tree_map - yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed) - else: - yield self._transform(sample) - - def __len__(self) -> int: - return len(self._dataset) - - -class FakeDataset(Dataset): - def __init__(self, model_config: _model.BaseModelConfig, num_samples: int): - self._num_samples = num_samples - self._observation_spec, self._action_spec = model_config.inputs_spec() - - def __getitem__(self, index: SupportsIndex) -> dict: - rng = jax.random.key(index.__index__()) - - def make_from_spec(spec: jax.ShapeDtypeStruct): - nonlocal rng - rng, data_rng = jax.random.split(rng) - # Remove the batch dimension. - shape = spec.shape[1:] - if spec.dtype == jnp.float32: - return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0) - if spec.dtype == jnp.int32: - return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048) - return jnp.zeros(shape=shape, dtype=spec.dtype) - - observation = jax.tree.map(make_from_spec, self._observation_spec) - action = jax.tree.map(make_from_spec, self._action_spec) - - return { - **observation.to_dict(), - "actions": action, - } - - def __len__(self) -> int: - return self._num_samples - - -def create_torch_dataset( - data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig -) -> Dataset: - """Create a dataset for training.""" - repo_id = data_config.repo_id - if repo_id is None: - raise ValueError("Repo ID is not set. Cannot create dataset.") - if repo_id == "fake": - return FakeDataset(model_config, num_samples=1024) - - dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id) - dataset = lerobot_dataset.LeRobotDataset( - data_config.repo_id, - delta_timestamps={ - key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys - }, - ) - - if data_config.prompt_from_task: - dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)]) - - return dataset - - -def create_rlds_dataset( - data_config: _config.DataConfig, - action_horizon: int, - batch_size: int, - *, - shuffle: bool = False, -) -> Dataset: - # At the moment, we only support DROID for RLDS datasets. - return DroidRldsDataset( - data_dir=data_config.rlds_data_dir, - batch_size=batch_size, - shuffle=shuffle, - action_chunk_size=action_horizon, - action_space=data_config.action_space, - filter_dict_path=data_config.filter_dict_path, - ) - - -def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset: - """Transform the dataset by applying the data transforms.""" - norm_stats = {} - if data_config.repo_id != "fake" and not skip_norm_stats: - if data_config.norm_stats is None: - raise ValueError( - "Normalization stats not found. " - "Make sure to run `scripts/compute_norm_stats.py --config-name=`." 
- ) - norm_stats = data_config.norm_stats - - return TransformedDataset( - dataset, - [ - *data_config.repack_transforms.inputs, - *data_config.data_transforms.inputs, - _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), - *data_config.model_transforms.inputs, - ], - ) - - -def transform_iterable_dataset( - dataset: IterableDataset, - data_config: _config.DataConfig, - *, - skip_norm_stats: bool = False, - is_batched: bool = False, -) -> IterableDataset: - """Transform the dataset by applying the data transforms.""" - norm_stats = {} - if data_config.repo_id != "fake" and not skip_norm_stats: - if data_config.norm_stats is None: - raise ValueError( - "Normalization stats not found. " - "Make sure to run `scripts/compute_norm_stats.py --config-name=`." - ) - norm_stats = data_config.norm_stats - - return IterableTransformedDataset( - dataset, - [ - *data_config.repack_transforms.inputs, - *data_config.data_transforms.inputs, - _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), - *data_config.model_transforms.inputs, - ], - is_batched=is_batched, - ) - - -def create_data_loader( - config: _config.TrainConfig, - *, - sharding: jax.sharding.Sharding | None = None, - shuffle: bool = False, - num_batches: int | None = None, - skip_norm_stats: bool = False, - framework: Literal["jax", "pytorch"] = "jax", -) -> DataLoader[tuple[_model.Observation, _model.Actions]]: - """Create a data loader for training. - - Args: - config: The training configuration. - sharding: The sharding to use for the data loader (JAX only). - shuffle: Whether to shuffle the data. - num_batches: Determines the number of batches to return. - skip_norm_stats: Whether to skip data normalization. - framework: The framework to use ("jax" or "pytorch"). - """ - data_config = config.data.create(config.assets_dirs, config.model) - logging.info(f"data_config: {data_config}") - - if data_config.rlds_data_dir is not None: - return create_rlds_data_loader( - data_config, - action_horizon=config.model.action_horizon, - batch_size=config.batch_size, - sharding=sharding, - shuffle=shuffle, - num_batches=num_batches, - skip_norm_stats=skip_norm_stats, - framework=framework, - ) - return create_torch_data_loader( - data_config, - model_config=config.model, - action_horizon=config.model.action_horizon, - batch_size=config.batch_size, - sharding=sharding, - shuffle=shuffle, - num_batches=num_batches, - num_workers=config.num_workers, - seed=config.seed, - skip_norm_stats=skip_norm_stats, - framework=framework, - ) - - -def create_torch_data_loader( - data_config: _config.DataConfig, - model_config: _model.BaseModelConfig, - action_horizon: int, - batch_size: int, - *, - sharding: jax.sharding.Sharding | None = None, - skip_norm_stats: bool = False, - shuffle: bool = False, - num_batches: int | None = None, - num_workers: int = 0, - seed: int = 0, - framework: str = "jax", -) -> DataLoader[tuple[_model.Observation, _model.Actions]]: - """Create a data loader for training. - - Args: - data_config: The data configuration. - action_horizon: The action horizon. - batch_size: The batch size. - sharding: The sharding to use for the data loader. If None, the data loader will - use a single device sharding. - skip_norm_stats: Whether to skip data normalization. - shuffle: Whether to shuffle the data. - num_batches: Determines the number of batches to return. If the number exceeds the - number of batches in the dataset, the data loader will loop over the dataset. 
- If not provided, will iterate over the dataset indefinitely. - num_workers: The number of worker processes to use. If zero, the data loader will - execute in the main process. - seed: The seed to use for shuffling the data. - """ - dataset = create_torch_dataset(data_config, action_horizon, model_config) - dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats) - - # Use TorchDataLoader for both frameworks - # For PyTorch DDP, create DistributedSampler and divide batch size by world size - # For JAX, divide by process count - sampler = None - if framework == "pytorch": - if torch.distributed.is_initialized(): - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=torch.distributed.get_world_size(), - rank=torch.distributed.get_rank(), - shuffle=shuffle, - drop_last=True, - ) - local_batch_size = batch_size // torch.distributed.get_world_size() - else: - local_batch_size = batch_size - else: - local_batch_size = batch_size // jax.process_count() - - logging.info(f"local_batch_size: {local_batch_size}") - data_loader = TorchDataLoader( - dataset, - local_batch_size=local_batch_size, - sharding=None if framework == "pytorch" else sharding, - shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler - sampler=sampler, - num_batches=num_batches, - num_workers=num_workers, - seed=seed, - framework=framework, - ) - - return DataLoaderImpl(data_config, data_loader) - - -def create_rlds_data_loader( - data_config: _config.DataConfig, - action_horizon: int, - batch_size: int, - *, - sharding: jax.sharding.Sharding | None = None, - skip_norm_stats: bool = False, - shuffle: bool = False, - num_batches: int | None = None, - framework: str = "jax", -) -> DataLoader[tuple[_model.Observation, _model.Actions]]: - """Create an RLDS data loader for training. - - Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md - - Args: - data_config: The data configuration. - action_horizon: The action horizon. - batch_size: The batch size. - sharding: The sharding to use for the data loader. If None, the data loader will - use a single device sharding. - skip_norm_stats: Whether to skip data normalization. - shuffle: Whether to shuffle the data. - num_batches: Determines the number of batches to return. If the number exceeds the - number of batches in the dataset, the data loader will loop over the dataset. - If not provided, will iterate over the dataset indefinitely. - """ - if framework == "pytorch": - raise NotImplementedError("PyTorch RLDS data loader is not supported yet") - dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle) - dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True) - - data_loader = RLDSDataLoader( - dataset, - sharding=sharding, - num_batches=num_batches, - ) - - return DataLoaderImpl(data_config, data_loader) - - -class TorchDataLoader: - """Torch data loader implementation.""" - - def __init__( - self, - dataset, - local_batch_size: int, - *, - sharding: jax.sharding.Sharding | None = None, - shuffle: bool = False, - sampler: torch.utils.data.Sampler | None = None, - num_batches: int | None = None, - num_workers: int = 0, - seed: int = 0, - framework: str = "jax", - ): - """Create a PyTorch data loader. - - Args: - dataset: The dataset to load. - local_batch_size: The local batch size for each process. - sharding: The sharding to use for the data loader. 
- shuffle: Whether to shuffle the data. - num_batches: If provided, determines the number of returned batches. If the - number is larger than the number of batches in the dataset, the data loader - will loop over the dataset. If not provided, will iterate over the dataset - indefinitely. - num_workers: The number of worker processes to use. If zero, the data loader will - execute in the main process. - seed: The seed to use for shuffling the data. - """ - if jax.process_count() > 1: - raise NotImplementedError("Data loading with multiple processes is not supported.") - - if len(dataset) < local_batch_size: - raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).") - - # Store sharding - None for PyTorch, JAX sharding for JAX - self._sharding = sharding - if sharding is None and framework == "jax": - # Use data parallel sharding by default for JAX only. - self._sharding = jax.sharding.NamedSharding( - jax.sharding.Mesh(jax.devices(), ("B",)), - jax.sharding.PartitionSpec("B"), - ) - self._num_batches = num_batches - - mp_context = None - if num_workers > 0: - mp_context = multiprocessing.get_context("spawn") - - generator = torch.Generator() - generator.manual_seed(seed) - self._data_loader = torch.utils.data.DataLoader( - typing.cast(torch.utils.data.Dataset, dataset), - batch_size=local_batch_size, - shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler - sampler=sampler, - num_workers=num_workers, - multiprocessing_context=mp_context, - persistent_workers=num_workers > 0, - collate_fn=_collate_fn, - worker_init_fn=_worker_init_fn, - drop_last=True, - generator=generator, - ) - - @property - def torch_loader(self) -> torch.utils.data.DataLoader: - return self._data_loader - - def __iter__(self): - num_items = 0 - while True: - data_iter = iter(self._data_loader) - while True: - if self._num_batches is not None and num_items >= self._num_batches: - return - try: - batch = next(data_iter) - except StopIteration: - break # We've exhausted the dataset. Create a new iterator and start over. - num_items += 1 - # For JAX, convert to sharded arrays; for PyTorch, return torch tensors - if self._sharding is not None: - yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) - else: - yield jax.tree.map(torch.as_tensor, batch) - - -def _collate_fn(items): - """Collate the batch elements into batched numpy arrays.""" - # Make sure to convert to numpy arrays before stacking since some of the incoming elements - # may be JAX arrays. - return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) - - -def _worker_init_fn(worker_id: int) -> None: - """Tell JAX inside the worker process not to preallocate the GPU memory.""" - # NOTE: This is called after jax is imported inside the worker process. This - # means that this approach will not work for selecting the backend. - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - - -class RLDSDataLoader: - """Shallow wrapper around the DROID data loader to make it compatible with openpi. - - All batching already happens in the DROID dataset, so we don't need to do anything here. 
- """ - - def __init__( - self, - dataset: DroidRldsDataset, - *, - sharding: jax.sharding.Sharding | None = None, - num_batches: int | None = None, - ): - self._dataset = dataset - self._num_batches = num_batches - - if jax.process_count() > 1: - raise NotImplementedError("Data loading with multiple processes is not supported.") - - if sharding is None: - # Use data parallel sharding by default. - sharding = jax.sharding.NamedSharding( - jax.sharding.Mesh(jax.devices(), ("B",)), - jax.sharding.PartitionSpec("B"), - ) - - self._sharding = sharding - self._num_batches = num_batches - - def __iter__(self): - num_items = 0 - while True: - data_iter = iter(self._dataset) - while True: - if self._num_batches is not None and num_items >= self._num_batches: - return - try: - batch = next(data_iter) - except StopIteration: - break # We've exhausted the dataset. Create a new iterator and start over. - num_items += 1 - yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) - - -class DataLoaderImpl(DataLoader): - def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader): - self._data_config = data_config - self._data_loader = data_loader - - def data_config(self) -> _config.DataConfig: - return self._data_config - - def __iter__(self): - for batch in self._data_loader: - yield _model.Observation.from_dict(batch), batch["actions"] diff --git a/capvector-pi05/src/openpi/training/data_loader_test.py b/capvector-pi05/src/openpi/training/data_loader_test.py deleted file mode 100644 index 3b77188885a14c4c82a23f288d912aab03028dcb..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/data_loader_test.py +++ /dev/null @@ -1,84 +0,0 @@ -import dataclasses - -import jax - -from openpi.models import pi0_config -from openpi.training import config as _config -from openpi.training import data_loader as _data_loader - - -def test_torch_data_loader(): - config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) - dataset = _data_loader.FakeDataset(config, 16) - - loader = _data_loader.TorchDataLoader( - dataset, - local_batch_size=4, - num_batches=2, - ) - batches = list(loader) - - assert len(batches) == 2 - for batch in batches: - assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) - - -def test_torch_data_loader_infinite(): - config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) - dataset = _data_loader.FakeDataset(config, 4) - - loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4) - data_iter = iter(loader) - - for _ in range(10): - _ = next(data_iter) - - -def test_torch_data_loader_parallel(): - config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) - dataset = _data_loader.FakeDataset(config, 10) - - loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2) - batches = list(loader) - - assert len(batches) == 2 - - for batch in batches: - assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) - - -def test_with_fake_dataset(): - config = _config.get_config("debug") - - loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2) - batches = list(loader) - - assert len(batches) == 2 - - for batch in batches: - assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch)) - - for _, actions in batches: - assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) - - -def 
test_with_real_dataset(): - config = _config.get_config("pi0_aloha_sim") - config = dataclasses.replace(config, batch_size=4) - - loader = _data_loader.create_data_loader( - config, - # Skip since we may not have the data available. - skip_norm_stats=True, - num_batches=2, - shuffle=True, - ) - # Make sure that we can get the data config. - assert loader.data_config().repo_id == config.data.repo_id - - batches = list(loader) - - assert len(batches) == 2 - - for _, actions in batches: - assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) diff --git a/capvector-pi05/src/openpi/training/droid_rlds_dataset.py b/capvector-pi05/src/openpi/training/droid_rlds_dataset.py deleted file mode 100644 index debbe73cde9718b3d17157799b676fd65b00c7ba..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/droid_rlds_dataset.py +++ /dev/null @@ -1,221 +0,0 @@ -""" -RLDS-based data loader for DROID. -While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID. -Thus, we provide a data loader example here that uses the RLDS data format. -The data loader also applies a few DROID-specific data filters / transformations. -""" - -from enum import Enum -from enum import auto -import json -import logging -from pathlib import Path - -import tqdm - -import openpi.shared.download as download - - -class DroidActionSpace(Enum): - """Action space for DROID dataset.""" - - JOINT_POSITION = auto() - JOINT_VELOCITY = auto() - - -class DroidRldsDataset: - def __init__( - self, - data_dir: str, - batch_size: int, - *, # Force keyword-only arguments - shuffle: bool = True, - action_chunk_size: int = 16, - # We default to joint position actions, since they allow policy evaluation in simulation. - action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION, - max_loaded_steps_per_episode: int = 100, - # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random. - shuffle_buffer_size: int = 250_000, - num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level - num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level - filter_dict_path=None, # Path to json file with indices to sample during training - ): - # Import tensorflow here to not make it mandatory in case RLDS data loader is not used. - import dlimp as dl - import tensorflow as tf - import tensorflow_datasets as tfds - - # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX) - tf.config.set_visible_devices([], "GPU") - - builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1") - dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads) - - # Filter out any unsuccessful trajectories -- we use the file name to check this - dataset = dataset.filter( - lambda traj: tf.strings.regex_full_match( - traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*" - ) - ) - - # # Repeat dataset so we never run out of data. - dataset = dataset.repeat() - - # Load the filter dictionary if provided. - # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample - # (e.g., - # { - # "": [[0, 100], [200, 300]] - # } - # means keep frames 0-99 and 200-299). 
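# A worked example of the filter-dictionary format described above, expanded into the per-frame
# keys that the hash table below is built from. The episode key is a made-up placeholder.
filter_dict = {"episode_0001": [[0, 100], [200, 300]]}

frame_keys = [
    f"{episode_key}--{t}"
    for episode_key, ranges in filter_dict.items()
    for start, end in ranges
    for t in range(start, end)
]
# Frames 0-99 and 200-299 pass the filter; everything else falls back to the default value (False).
assert "episode_0001--0" in frame_keys
assert "episode_0001--150" not in frame_keys
assert len(frame_keys) == 200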
- if filter_dict_path is not None: - cached_filter_dict_path = download.maybe_download(filter_dict_path) - with Path(cached_filter_dict_path).open("r") as f: - filter_dict = json.load(f) - - logging.info(f"Using filter dictionary with {len(filter_dict)} episodes") - - keys_tensor = [] - values_tensor = [] - - for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."): - for start, end in ranges: - for t in range(start, end): - frame_key = f"{episode_key}--{t}" - keys_tensor.append(frame_key) - values_tensor.append(True) - self.filter_table = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False - ) - logging.info("Filter hash table initialized") - else: - self.filter_table = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True - ) - - def restructure(traj): - """Reformat observation and action keys, sample language instruction.""" - # Important: we use joint *position* action space -- easier to simulate! - actions = tf.concat( - ( - ( - traj["action_dict"]["joint_position"] - if action_space == DroidActionSpace.JOINT_POSITION - else traj["action_dict"]["joint_velocity"] - ), - traj["action_dict"]["gripper_position"], - ), - axis=-1, - ) - # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time). - # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera. - exterior_img = tf.cond( - tf.random.uniform(shape=[]) > 0.5, - lambda: traj["observation"]["exterior_image_1_left"], - lambda: traj["observation"]["exterior_image_2_left"], - ) - wrist_img = traj["observation"]["wrist_image_left"] - # Randomly sample one of the three language instructions - instruction = tf.random.shuffle( - [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]] - )[0] - - traj_len = tf.shape(traj["action"])[0] - indices = tf.as_string(tf.range(traj_len)) - - # Data filtering: - # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path, - # and each step's time step index. This will index into the filter hash table, and if it returns true, - # then the frame passes the filter. 
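# A brief sketch of the key format: the step IDs computed below use the same "--"-joined layout
# as the hash-table keys built above, so a frame passes the filter exactly when its key was
# inserted. The paths here are made-up placeholders, not real DROID episode metadata.
recording_folderpath = "droid/session_0001"
file_path = "droid/session_0001/trajectory.h5"
t = 17
step_id_example = f"{recording_folderpath}--{file_path}--{t}"
# -> "droid/session_0001--droid/session_0001/trajectory.h5--17"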
- step_id = ( - traj["traj_metadata"]["episode_metadata"]["recording_folderpath"] - + "--" - + traj["traj_metadata"]["episode_metadata"]["file_path"] - + "--" - + indices - ) - passes_filter = self.filter_table.lookup(step_id) - - return { - "actions": actions, - "observation": { - "image": exterior_img, - "wrist_image": wrist_img, - "joint_position": traj["observation"]["joint_position"], - "gripper_position": traj["observation"]["gripper_position"], - }, - "prompt": instruction, - "step_id": step_id, - "passes_filter": passes_filter, - } - - dataset = dataset.traj_map(restructure, num_parallel_calls) - - def chunk_actions(traj): - """Splits episode into action chunks.""" - traj_len = tf.shape(traj["actions"])[0] - - # For each step in the trajectory, construct indices for the next n actions - action_chunk_indices = tf.broadcast_to( - tf.range(action_chunk_size)[None], - [traj_len, action_chunk_size], - ) + tf.broadcast_to( - tf.range(traj_len)[:, None], - [traj_len, action_chunk_size], - ) - - # Cap to length of the sequence --> final chunks will repeat the last action - # This makes sense, since we are using absolute joint + gripper position actions - action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1) - - # Gather the actions for each chunk - traj["actions"] = tf.gather(traj["actions"], action_chunk_indices) - return traj - - dataset = dataset.traj_map(chunk_actions, num_parallel_calls) - - # Flatten: map from trajectory dataset to dataset of individual action chunks - dataset = dataset.flatten(num_parallel_calls=num_parallel_calls) - - # Filter data that doesn't pass the filter - def filter_from_dict(frame): - return frame["passes_filter"] - - dataset = dataset.filter(filter_from_dict) - - # Remove "passes_filter" key from output - def remove_passes_filter(frame): - frame.pop("passes_filter") - return frame - - dataset = dataset.map(remove_passes_filter) - - # Decode images: RLDS saves encoded images, only decode now for efficiency - def decode_images(traj): - traj["observation"]["image"] = tf.io.decode_image( - traj["observation"]["image"], expand_animations=False, dtype=tf.uint8 - ) - traj["observation"]["wrist_image"] = tf.io.decode_image( - traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8 - ) - return traj - - dataset = dataset.frame_map(decode_images, num_parallel_calls) - - # Shuffle, batch - dataset = dataset.shuffle(shuffle_buffer_size) - dataset = dataset.batch(batch_size) - # Note =>> Seems to reduce memory usage without affecting speed? - dataset = dataset.with_ram_budget(1) - - self.dataset = dataset - self.batch_size = batch_size - self.shuffle = shuffle - - def __iter__(self): - yield from self.dataset.as_numpy_iterator() - - def __len__(self): - # This is the approximate number of samples in DROID after filtering. - # Easier to hardcode than to iterate through the dataset and compute it. 
- return 20_000_000 diff --git a/capvector-pi05/src/openpi/training/misc/roboarena_config.py b/capvector-pi05/src/openpi/training/misc/roboarena_config.py deleted file mode 100644 index e0f366a43caf004d0f291db5af1d1678083888ca..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/misc/roboarena_config.py +++ /dev/null @@ -1,116 +0,0 @@ -"""RoboArena baseline policy configs.""" - -from typing import TypeAlias - -import openpi.models.model as _model -import openpi.models.pi0_config as pi0_config -import openpi.models.pi0_fast as pi0_fast -import openpi.models.tokenizer as _tokenizer -import openpi.policies.droid_policy as droid_policy -import openpi.transforms as _transforms - -ModelType: TypeAlias = _model.ModelType - - -def get_roboarena_configs(): - # Import here to avoid circular imports. - from openpi.training.config import AssetsConfig - from openpi.training.config import DataConfig - from openpi.training.config import SimpleDataConfig - from openpi.training.config import TrainConfig - - return [ - # - # RoboArena DROID baseline inference configs. - # - TrainConfig( - # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer. - name="paligemma_binning_droid", - model=pi0_fast.Pi0FASTConfig( - action_dim=8, - action_horizon=15, - max_token_len=400, - fast_model_tokenizer=_tokenizer.BinningTokenizer, - ), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - TrainConfig( - # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer). - name="paligemma_fast_droid", - model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - TrainConfig( - # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset). - name="paligemma_fast_specialist_droid", - model=pi0_fast.Pi0FASTConfig( - action_dim=8, - action_horizon=15, - fast_model_tokenizer=_tokenizer.FASTTokenizer, - fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"}, - ), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - TrainConfig( - # Trained from PaliGemma, using FSQ tokenizer. 
- name="paligemma_vq_droid", - model=pi0_fast.Pi0FASTConfig( - action_dim=8, - action_horizon=15, - fast_model_tokenizer=_tokenizer.FSQTokenizer, - fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"}, - ), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - TrainConfig( - # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma. - name="paligemma_diffusion_droid", - model=pi0_config.Pi0Config(action_horizon=10, action_dim=8), - data=SimpleDataConfig( - assets=AssetsConfig(asset_id="droid"), - data_transforms=lambda model: _transforms.Group( - inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)], - outputs=[droid_policy.DroidOutputs()], - ), - base_config=DataConfig( - prompt_from_task=True, - ), - ), - ), - ] diff --git a/capvector-pi05/src/openpi/training/optimizer.py b/capvector-pi05/src/openpi/training/optimizer.py deleted file mode 100644 index a233bfd0e2d0fc62c295bd3ab82b35726a5fc545..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/optimizer.py +++ /dev/null @@ -1,109 +0,0 @@ -import dataclasses -from typing import Protocol, runtime_checkable - -import jax.numpy as jnp -import optax - -import openpi.shared.array_typing as at - - -@runtime_checkable -class LRScheduleConfig(Protocol): - def create(self) -> optax.Schedule: ... - - -@dataclasses.dataclass(frozen=True) -class CosineDecaySchedule(LRScheduleConfig): - """Cosine decay schedule with warmup.""" - - warmup_steps: int = 1_000 - peak_lr: float = 2.5e-5 - decay_steps: int = 30_000 - decay_lr: float = 2.5e-6 - - def create(self) -> optax.Schedule: - return optax.warmup_cosine_decay_schedule( - init_value=self.peak_lr / (self.warmup_steps + 1), - peak_value=self.peak_lr, - warmup_steps=self.warmup_steps, - decay_steps=self.decay_steps, - end_value=self.decay_lr, - ) - - -@dataclasses.dataclass(frozen=True) -class RsqrtDecaySchedule(LRScheduleConfig): - """Inverse square root decay schedule with warmup.""" - - warmup_steps: int = 1_000 - peak_lr: float = 5e-5 - timescale: float = 10_000 - - def create(self) -> optax.Schedule: - return optax.join_schedules( - [ - optax.linear_schedule( - init_value=self.peak_lr / (self.warmup_steps + 1), - end_value=self.peak_lr, - transition_steps=self.warmup_steps, - ), - lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale), - ], - [self.warmup_steps], - ) - - -@runtime_checkable -class OptimizerConfig(Protocol): - def create( - self, - lr: optax.ScalarOrSchedule, - weight_decay_mask: at.PyTree | None = None, - ) -> optax.GradientTransformation: ... - - -@dataclasses.dataclass(frozen=True) -class AdamW(OptimizerConfig): - """AdamW optimizer.""" - - b1: float = 0.9 - b2: float = 0.95 - eps: float = 1e-8 - # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value. 
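# A quick numeric sketch of what CosineDecaySchedule.create() above resolves to with its default
# values, evaluated at a few example steps. Uses optax directly; the numbers match the defaults above.
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=2.5e-5 / (1_000 + 1),  # peak_lr / (warmup_steps + 1)
    peak_value=2.5e-5,
    warmup_steps=1_000,
    decay_steps=30_000,
    end_value=2.5e-6,
)
for step in (0, 1_000, 15_000, 30_000):
    print(step, float(schedule(step)))  # ramps to 2.5e-5 at the warmup peak, decays to 2.5e-6 by step 30k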
- weight_decay: float = 1e-10 - clip_gradient_norm: float = 1.0 - - def create( - self, - lr: optax.ScalarOrSchedule, - weight_decay_mask: at.PyTree | None = None, - ) -> optax.GradientTransformation: - tx = optax.adamw( - lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask - ) - - return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx) - - -@dataclasses.dataclass(frozen=True) -class SGD(OptimizerConfig): - """SGD optimizer.""" - - lr: float = 5e-5 - momentum: float = 0.9 - nesterov: bool = False - - def create( - self, - lr: optax.ScalarOrSchedule, - weight_decay_mask: at.PyTree | None = None, - ) -> optax.GradientTransformation: - assert weight_decay_mask is None, "Weight decay is not supported for SGD" - return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov) - - -def create_optimizer( - optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None -) -> optax.GradientTransformation: - lr = lr_schedule.create() - return optimizer.create(lr, weight_decay_mask=weight_decay_mask) diff --git a/capvector-pi05/src/openpi/training/sharding.py b/capvector-pi05/src/openpi/training/sharding.py deleted file mode 100644 index 6b34e5e11069637028f1792a14fd8be95073dfcd..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/sharding.py +++ /dev/null @@ -1,102 +0,0 @@ -import contextlib -import logging - -import jax -import numpy as np - -BATCH_AXIS = "batch" -FSDP_AXIS = "fsdp" -# In FSDP, we shard the data across both the batch and FSDP axes. -DATA_AXIS = (BATCH_AXIS, FSDP_AXIS) - - -class _MeshState: - active_mesh: jax.sharding.Mesh | None = None - - -def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh: - if jax.device_count() % num_fsdp_devices != 0: - raise ValueError( - f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}." - ) - mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices) - return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS)) - - -@contextlib.contextmanager -def set_mesh(mesh: jax.sharding.Mesh): - """Plumbing the mesh deep into the module tree is extremeley cumbersome; until the JAX team lands a better API, a - custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used - in `activation_sharding_constraint` below.""" - if _MeshState.active_mesh is not None: - raise ValueError("Cannot nest set_mesh context managers.") - _MeshState.active_mesh = mesh - try: - yield - finally: - _MeshState.active_mesh = None - - -def activation_sharding_constraint(pytree): - if _MeshState.active_mesh is None: - return pytree - return jax.lax.with_sharding_constraint( - pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS)) - ) - - -def fsdp_sharding( - pytree, - mesh: jax.sharding.Mesh, - *, - min_size_mbytes: int = 4, # 4 MiB - log: bool = False, -): - """Apply FSDP sharding to a pytree of arrays based on the mesh shape. - - Args: - pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr) - will be considered for sharding. - mesh: The mesh being used for applying sharding on to pytree. - min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this - will be replicated. - log: If true, will log the sharding decisions for arrays that are being considered for sharding. 
- - Returns: - The sharded pytree. - """ - min_size_bytes = min_size_mbytes * 2**20 - - def _shard_arr(kp, array: jax.ShapeDtypeStruct): - # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging - if mesh.shape[FSDP_AXIS] == 1: - return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - # replicate scalar and vector arrays - if not hasattr(array, "shape"): - return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - if len(array.shape) < 2: - return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - # replicate small arrays - if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes: - return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - - # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension - axes = np.argsort(array.shape)[::-1] - spec = [None] * len(axes) - for i in axes: - if array.shape[i] % mesh.shape[FSDP_AXIS] == 0: - if log: - logging.info( - f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}" - ) - spec[i] = FSDP_AXIS - return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec)) - - # replicate if no valid sharding was found - if log: - logging.warning( - f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}" - ) - return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - - return jax.tree_util.tree_map_with_path(_shard_arr, pytree) diff --git a/capvector-pi05/src/openpi/training/utils.py b/capvector-pi05/src/openpi/training/utils.py deleted file mode 100644 index 5593fee824510233d18c527b7a4f16470970fca3..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from collections.abc import Callable -from typing import Any - -from flax import nnx -from flax import struct -import jax -import optax - -from openpi.models import model as _model -from openpi.shared import array_typing as at - - -@at.typecheck -@struct.dataclass -class TrainState: - step: at.Int[at.ArrayLike, ""] - params: nnx.State - model_def: nnx.GraphDef[_model.BaseModel] - opt_state: optax.OptState - tx: optax.GradientTransformation = struct.field(pytree_node=False) - - ema_decay: float | None = struct.field(pytree_node=False) - ema_params: nnx.State | None = None - - -@at.typecheck -def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str: - """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert - the leaf values to more meaningful strings. 
- """ - tree, _ = jax.tree_util.tree_flatten_with_path(tree) - return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree) - - -@at.typecheck -def array_tree_to_info(tree: at.PyTree) -> str: - """Converts a PyTree of arrays into a human-readable string for logging.""" - return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}") diff --git a/capvector-pi05/src/openpi/training/weight_loaders.py b/capvector-pi05/src/openpi/training/weight_loaders.py deleted file mode 100644 index f13f3cb90475a59e76c81392c0ee8246192e8d24..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/training/weight_loaders.py +++ /dev/null @@ -1,104 +0,0 @@ -import dataclasses -import logging -import re -from typing import Protocol, runtime_checkable - -import flax.traverse_util -import numpy as np - -import openpi.models.model as _model -import openpi.shared.array_typing as at -import openpi.shared.download as download - -logger = logging.getLogger(__name__) - - -@runtime_checkable -class WeightLoader(Protocol): - def load(self, params: at.Params) -> at.Params: - """Loads the model weights. - - Args: - params: Parameters of the model. This is a nested structure of array-like objects that - represent the model's parameters. - - Returns: - Loaded parameters. The structure must be identical to `params`. If returning a subset of - the parameters the loader must merge the loaded parameters with `params`. - """ - - -@dataclasses.dataclass(frozen=True) -class NoOpWeightLoader(WeightLoader): - def load(self, params: at.Params) -> at.Params: - return params - - -@dataclasses.dataclass(frozen=True) -class CheckpointWeightLoader(WeightLoader): - """Loads an entire set of weights from a checkpoint. - - Compatible with: - trained checkpoints: - example: "./checkpoints////params" - released checkpoints: - example: "gs://openpi-assets/checkpoints//params" - """ - - params_path: str - - def load(self, params: at.Params) -> at.Params: - # We are loading np.ndarray and relying on the training code to properly convert and shard the params. - loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray) - # Add all missing LoRA weights. - return _merge_params(loaded_params, params, missing_regex=".*lora.*") - - -@dataclasses.dataclass(frozen=True) -class PaliGemmaWeightLoader(WeightLoader): - """Loads weights from the official PaliGemma checkpoint. - - This will overwrite existing weights with similar names while keeping all extra weights intact. - This allows us to support the action expert which is used by the Pi0 model. - """ - - def load(self, params: at.Params) -> at.Params: - path = download.maybe_download( - "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz", gs={"token": "anon"} - ) - with path.open("rb") as f: - flat_params = dict(np.load(f, allow_pickle=False)) - loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]} - # Add all missing weights. - return _merge_params(loaded_params, params, missing_regex=".*") - - -def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params: - """Merges the loaded parameters with the reference parameters. - - Args: - loaded_params: The parameters to merge. - params: The reference parameters. - missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters. - - Returns: - A new dictionary with the merged parameters. 
- """ - flat_ref = flax.traverse_util.flatten_dict(params, sep="/") - flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/") - - # First, take all weights that are a subset of the reference weights. - result = {} - for k, v in flat_loaded.items(): - if k in flat_ref: - result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v - - flat_loaded.clear() - - # Then, merge any missing weights as defined by the missing regex. - pattern = re.compile(missing_regex) - for k in {k for k in flat_ref if pattern.fullmatch(k)}: - if k not in result: - result[k] = flat_ref[k] - - return flax.traverse_util.unflatten_dict(result, sep="/") diff --git a/capvector-pi05/src/openpi/transforms.py b/capvector-pi05/src/openpi/transforms.py deleted file mode 100644 index 782cf4cab861f39d4535563fb1f5b8aaf1198b49..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/transforms.py +++ /dev/null @@ -1,469 +0,0 @@ -from collections.abc import Callable, Mapping, Sequence -import dataclasses -import re -from typing import Protocol, TypeAlias, TypeVar, runtime_checkable - -import flax.traverse_util as traverse_util -import jax -import numpy as np -from openpi_client import image_tools - -from openpi.models import tokenizer as _tokenizer -from openpi.shared import array_typing as at -from openpi.shared import normalize as _normalize - -DataDict: TypeAlias = at.PyTree -NormStats: TypeAlias = _normalize.NormStats - - -T = TypeVar("T") -S = TypeVar("S") - - -@runtime_checkable -class DataTransformFn(Protocol): - def __call__(self, data: DataDict) -> DataDict: - """Apply transformation to the data. - - Args: - data: The data to apply the transform to. This is a possibly nested dictionary that contains - unbatched data elements. Each leaf is expected to be a numpy array. Using JAX arrays is allowed - but not recommended since it may result in extra GPU memory usage inside data loader worker - processes. - - Returns: - The transformed data. Could be the input `data` that was modified in place, or a new data structure. - """ - - -@dataclasses.dataclass(frozen=True) -class Group: - """A group of transforms.""" - - # Transforms that are applied to the model input data. - inputs: Sequence[DataTransformFn] = () - - # Transforms that are applied to the model output data. - outputs: Sequence[DataTransformFn] = () - - def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Sequence[DataTransformFn] = ()) -> "Group": - """Append transforms to the group and return a new group. - - Args: - inputs: Appended to the *end* of the current input transforms. - outputs: Appended to the *beginning* of the current output transforms. - - Returns: - A new group with the appended transforms. - """ - return Group(inputs=(*self.inputs, *inputs), outputs=(*outputs, *self.outputs)) - - -@dataclasses.dataclass(frozen=True) -class CompositeTransform(DataTransformFn): - """A composite transform that applies a sequence of transforms in order.""" - - transforms: Sequence[DataTransformFn] - - def __call__(self, data: DataDict) -> DataDict: - for transform in self.transforms: - data = transform(data) - return data - - -def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn: - """Compose a sequence of transforms into a single transform.""" - return CompositeTransform(transforms) - - -@dataclasses.dataclass(frozen=True) -class RepackTransform(DataTransformFn): - """Repacks an input dictionary into a new dictionary. 
- - Repacking is defined using a dictionary where the keys are the new keys and the values - are the flattened paths to the old keys. We use '/' as the separator during flattening. - - Example: - { - "images": { - "cam_high": "observation.images.top", - "cam_low": "observation.images.bottom", - }, - "state": "observation.state", - "actions": "action", - } - """ - - structure: at.PyTree[str] - - def __call__(self, data: DataDict) -> DataDict: - flat_item = flatten_dict(data) - return jax.tree.map(lambda k: flat_item[k], self.structure) - - -@dataclasses.dataclass(frozen=True) -class InjectDefaultPrompt(DataTransformFn): - prompt: str | None - - def __call__(self, data: DataDict) -> DataDict: - if self.prompt is not None and "prompt" not in data: - data["prompt"] = np.asarray(self.prompt) - return data - - -@dataclasses.dataclass(frozen=True) -class Normalize(DataTransformFn): - norm_stats: at.PyTree[NormStats] | None - # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. - use_quantiles: bool = False - # If true, will raise an error if any of the keys in the norm stats are not present in the data. - strict: bool = False - - def __post_init__(self): - if self.norm_stats is not None and self.use_quantiles: - _assert_quantile_stats(self.norm_stats) - - def __call__(self, data: DataDict) -> DataDict: - if self.norm_stats is None: - return data - - return apply_tree( - data, - self.norm_stats, - self._normalize_quantile if self.use_quantiles else self._normalize, - strict=self.strict, - ) - - def _normalize(self, x, stats: NormStats): - mean, std = stats.mean[..., : x.shape[-1]], stats.std[..., : x.shape[-1]] - return (x - mean) / (std + 1e-6) - - def _normalize_quantile(self, x, stats: NormStats): - assert stats.q01 is not None - assert stats.q99 is not None - q01, q99 = stats.q01[..., : x.shape[-1]], stats.q99[..., : x.shape[-1]] - return (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0 - - -@dataclasses.dataclass(frozen=True) -class Unnormalize(DataTransformFn): - norm_stats: at.PyTree[NormStats] | None - # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. - use_quantiles: bool = False - - def __post_init__(self): - if self.norm_stats is not None and self.use_quantiles: - _assert_quantile_stats(self.norm_stats) - - def __call__(self, data: DataDict) -> DataDict: - if self.norm_stats is None: - return data - - # Make sure that all the keys in the norm stats are present in the data. 
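# A tiny numeric check of the quantile normalization above and its inverse in Unnormalize below:
# q01 maps to -1 and q99 maps to +1 (up to the 1e-6 epsilon). Values are illustrative only.
import numpy as np

q01, q99 = np.array([-0.5]), np.array([0.5])
x = np.array([0.25])

x_norm = (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0       # ~= 0.5
x_back = (x_norm + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01  # ~= 0.25 again
assert np.allclose(x_back, x)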
- return apply_tree( - data, - self.norm_stats, - self._unnormalize_quantile if self.use_quantiles else self._unnormalize, - strict=True, - ) - - def _unnormalize(self, x, stats: NormStats): - mean = pad_to_dim(stats.mean, x.shape[-1], axis=-1, value=0.0) - std = pad_to_dim(stats.std, x.shape[-1], axis=-1, value=1.0) - return x * (std + 1e-6) + mean - - def _unnormalize_quantile(self, x, stats: NormStats): - assert stats.q01 is not None - assert stats.q99 is not None - q01, q99 = stats.q01, stats.q99 - if (dim := q01.shape[-1]) < x.shape[-1]: - return np.concatenate([(x[..., :dim] + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01, x[..., dim:]], axis=-1) - return (x + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01 - - -@dataclasses.dataclass(frozen=True) -class ResizeImages(DataTransformFn): - height: int - width: int - - def __call__(self, data: DataDict) -> DataDict: - data["image_padding_mask"] = dict() - for cam in data["image"]: - resized_img, img_padding_mask = image_tools.resize_with_pad( - data["image"][cam], - self.height, - self.width, - return_mask=True - ) - data["image"][cam] = resized_img - data["image_padding_mask"][cam] = img_padding_mask - return data - - -@dataclasses.dataclass(frozen=True) -class SubsampleActions(DataTransformFn): - stride: int - - def __call__(self, data: DataDict) -> DataDict: - data["actions"] = data["actions"][:: self.stride] - return data - - -@dataclasses.dataclass(frozen=True) -class DeltaActions(DataTransformFn): - """Repacks absolute actions into delta action space.""" - - # Boolean mask for the action dimensions to be repacked into delta action space. Length - # can be smaller than the actual number of dimensions. If None, this transform is a no-op. - # See `make_bool_mask` for more details. - mask: Sequence[bool] | None - - def __call__(self, data: DataDict) -> DataDict: - if "actions" not in data or self.mask is None: - return data - - state, actions = data["state"], data["actions"] - mask = np.asarray(self.mask) - dims = mask.shape[-1] - actions[..., :dims] -= np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2) - data["actions"] = actions - - return data - - -@dataclasses.dataclass(frozen=True) -class AbsoluteActions(DataTransformFn): - """Repacks delta actions into absolute action space.""" - - # Boolean mask for the action dimensions to be repacked into absolute action space. Length - # can be smaller than the actual number of dimensions. If None, this transform is a no-op. - # See `make_bool_mask` for more details. 
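`DeltaActions` above and `AbsoluteActions` below are exact inverses on the dimensions selected by the mask, and `make_bool_mask` (defined later in this file) builds that mask from signed run lengths. A short sketch for a hypothetical 7-dimensional action space with six relative arm dimensions and an absolute gripper dimension; that split is assumed purely for illustration:

```python
import numpy as np

# make_bool_mask(6, -1) -> (True,) * 6 + (False,): first six dims relative to the state, last dim absolute.
mask = np.array([True] * 6 + [False])

state = np.arange(7, dtype=np.float32)                      # current robot state
actions = np.tile(np.arange(7, dtype=np.float32), (2, 1))   # action chunk of length 2

delta = actions.copy()
delta[..., :7] -= np.where(mask, state, 0)     # DeltaActions: subtract the state on masked dims only
absolute = delta.copy()
absolute[..., :7] += np.where(mask, state, 0)  # AbsoluteActions: add it back
assert np.allclose(absolute, actions)
```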
- mask: Sequence[bool] | None - - def __call__(self, data: DataDict) -> DataDict: - if "actions" not in data or self.mask is None: - return data - - state, actions = data["state"], data["actions"] - mask = np.asarray(self.mask) - dims = mask.shape[-1] - actions[..., :dims] += np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2) - data["actions"] = actions - - return data - - -@dataclasses.dataclass(frozen=True) -class TokenizePrompt(DataTransformFn): - tokenizer: _tokenizer.PaligemmaTokenizer - discrete_state_input: bool = False - - def __call__(self, data: DataDict) -> DataDict: - if (prompt := data.pop("prompt", None)) is None: - raise ValueError("Prompt is required") - - if self.discrete_state_input: - if (state := data.get("state", None)) is None: - raise ValueError("State is required.") - else: - state = None - - if not isinstance(prompt, str): - prompt = prompt.item() - - tokens, token_masks = self.tokenizer.tokenize(prompt, state) - return {**data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_masks} - - -@dataclasses.dataclass(frozen=True) -class TokenizeFASTInputs(DataTransformFn): - tokenizer: _tokenizer.FASTTokenizer - - def __call__(self, data: DataDict) -> DataDict: - if (prompt := data.pop("prompt", None)) is None: - raise ValueError("Prompt is required") - - if not isinstance(prompt, str): - prompt = prompt.item() - - state, actions = data["state"], data.get("actions") - tokens, token_mask, ar_mask, loss_mask = self.tokenizer.tokenize(prompt, state, actions) - return { - **data, - "tokenized_prompt": tokens, - "tokenized_prompt_mask": token_mask, - "token_ar_mask": ar_mask, - "token_loss_mask": loss_mask, - } - - -@dataclasses.dataclass(frozen=True) -class ExtractFASTActions(DataTransformFn): - tokenizer: _tokenizer.FASTTokenizer - action_horizon: int - action_dim: int - - def __call__(self, data: DataDict) -> DataDict: - if "actions" not in data: - return data - # Model outputs are saved in "actions", but for FAST models they represent tokens. - tokens = data.pop("actions") - actions = self.tokenizer.extract_actions(tokens.astype(np.int32), self.action_horizon, self.action_dim) - return { - **data, - "actions": actions, - } - - -@dataclasses.dataclass(frozen=True) -class PromptFromLeRobotTask(DataTransformFn): - """Extracts a prompt from the current LeRobot dataset task.""" - - # Contains the LeRobot dataset tasks (dataset.meta.tasks). - tasks: dict[int, str] - - def __call__(self, data: DataDict) -> DataDict: - if "task_index" not in data: - raise ValueError('Cannot extract prompt without "task_index"') - - task_index = int(data["task_index"]) - if (prompt := self.tasks.get(task_index)) is None: - raise ValueError(f"{task_index=} not found in task mapping: {self.tasks}") - - return {**data, "prompt": prompt} - - -@dataclasses.dataclass(frozen=True) -class PadStatesAndActions(DataTransformFn): - """Zero-pads states and actions to the model action dimension.""" - - model_action_dim: int - - def __call__(self, data: DataDict) -> DataDict: - data["state"] = pad_to_dim(data["state"], self.model_action_dim, axis=-1) - if "actions" in data: - data["actions"] = pad_to_dim(data["actions"], self.model_action_dim, axis=-1) - return data - - -def flatten_dict(tree: at.PyTree) -> dict: - """Flatten a nested dictionary. Uses '/' as the separator.""" - return traverse_util.flatten_dict(tree, sep="/") - - -def unflatten_dict(tree: dict) -> at.PyTree: - """Unflatten a flattened dictionary. 
Assumes that '/' was used as a separator.""" - return traverse_util.unflatten_dict(tree, sep="/") - - -def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree: - """Transform the structure of a nested dictionary using a set of patterns. - - The transformation is defined using the `patterns` dictionary. The keys are the - input keys that should be matched and the values are the new names inside the output - dictionary. If the value is None, the input key is removed. - - Both keys and values should represent flattened paths using '/' as the separator. - Keys can be regular expressions and values can include backreferences to the - matched groups (see `re.sub` for more details). Note that the regular expression - must match the entire key. - - The order inside the `patterns` dictionary is important. Only the first pattern that - matches the input key will be used. - - See unit tests for more examples. - - Args: - patterns: A mapping from old keys to new keys. - tree: The nested dictionary to transform. - - Returns: - The transformed nested dictionary. - """ - data = flatten_dict(tree) - - # Compile the patterns. - compiled = {re.compile(k): v for k, v in patterns.items()} - - output = {} - for k in data: - for pattern, repl in compiled.items(): - if pattern.fullmatch(k): - new_k = pattern.sub(repl, k, count=1) if repl is not None else None - break - else: - # Use the original key if no match is found. - new_k = k - - if new_k is not None: - if new_k in output: - raise ValueError(f"Key '{new_k}' already exists in output") - output[new_k] = data[k] - - # Validate the output structure to make sure that it can be unflattened. - names = sorted(output) - for i in range(len(names) - 1): - name, next_name = names[i : i + 2] - if next_name.startswith(name + "/"): - raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'") - - return unflatten_dict(output) - - -def apply_tree( - tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False -) -> at.PyTree[T]: - tree = flatten_dict(tree) - selector = flatten_dict(selector) - - def transform(k: str, v: T) -> T: - if k in selector: - return fn(v, selector[k]) - return v - - if strict: - for k in selector: - if k not in tree: - raise ValueError(f"Selector key {k} not found in tree") - - return unflatten_dict({k: transform(k, v) for k, v in tree.items()}) - - -def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray: - """Pad an array to the target dimension with zeros along the specified axis.""" - current_dim = x.shape[axis] - if current_dim < target_dim: - pad_width = [(0, 0)] * len(x.shape) - pad_width[axis] = (0, target_dim - current_dim) - return np.pad(x, pad_width, constant_values=value) - return x - - -def make_bool_mask(*dims: int) -> tuple[bool, ...]: - """Make a boolean mask for the given dimensions. - - Example: - make_bool_mask(2, -2, 2) == (True, True, False, False, True, True) - make_bool_mask(2, 0, 2) == (True, True, True, True) - - Args: - dims: The dimensions to make the mask for. - - Returns: - A tuple of booleans. - """ - result = [] - for dim in dims: - if dim > 0: - result.extend([True] * (dim)) - else: - result.extend([False] * (-dim)) - return tuple(result) - - -def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None: - for k, v in flatten_dict(norm_stats).items(): - if v.q01 is None or v.q99 is None: - raise ValueError( - f"quantile stats must be provided if use_quantile_norm is True. 
Key {k} is missing q01 or q99." - ) diff --git a/capvector-pi05/src/openpi/transforms_test.py b/capvector-pi05/src/openpi/transforms_test.py deleted file mode 100644 index 2ef17015132c940545a7af27b470806412431fbe..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/openpi/transforms_test.py +++ /dev/null @@ -1,121 +0,0 @@ -import numpy as np -import pytest - -import openpi.models.tokenizer as _tokenizer -import openpi.transforms as _transforms - - -def test_repack_transform(): - transform = _transforms.RepackTransform( - structure={ - "a": {"b": "b/c"}, - "d": "e/f", - } - ) - item = {"b": {"c": 1}, "e": {"f": 2}} - assert transform(item) == {"a": {"b": 1}, "d": 2} - - -def test_delta_actions(): - item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} - - transform = _transforms.DeltaActions(mask=[False, True]) - transformed = transform(item) - - assert np.all(transformed["state"] == np.array([1, 2, 3])) - assert np.all(transformed["actions"] == np.array([[3, 2, 5], [5, 4, 7]])) - - -def test_delta_actions_noop(): - item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} - - # No-op when the mask is disabled. - transform = _transforms.DeltaActions(mask=None) - assert transform(item) is item - - # No-op when there are no actions in the input. - del item["actions"] - transform = _transforms.DeltaActions(mask=[True, False]) - assert transform(item) is item - - -def test_absolute_actions(): - item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} - - transform = _transforms.AbsoluteActions(mask=[False, True]) - transformed = transform(item) - - assert np.all(transformed["state"] == np.array([1, 2, 3])) - assert np.all(transformed["actions"] == np.array([[3, 6, 5], [5, 8, 7]])) - - -def test_absolute_actions_noop(): - item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} - - # No-op when the mask is disabled. - transform = _transforms.AbsoluteActions(mask=None) - assert transform(item) is item - - # No-op when there are no actions in the input. - del item["actions"] - transform = _transforms.AbsoluteActions(mask=[True, False]) - assert transform(item) is item - - -def test_make_bool_mask(): - assert _transforms.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True) - assert _transforms.make_bool_mask(2, 0, 2) == (True, True, True, True) - - -def test_tokenize_prompt(): - tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12) - transform = _transforms.TokenizePrompt(tokenizer) - - data = transform({"prompt": "Hello, world!"}) - - tok_prompt, tok_mask = tokenizer.tokenize("Hello, world!") - assert np.allclose(tok_prompt, data["tokenized_prompt"]) - assert np.allclose(tok_mask, data["tokenized_prompt_mask"]) - - -def test_tokenize_no_prompt(): - transform = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer()) - - with pytest.raises(ValueError, match="Prompt is required"): - transform({}) - - -def test_transform_dict(): - # Rename and remove keys. - input = {"a": {"b": 1, "c": 2}} - output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, input) - assert output == {"a": {"c": 1}} - - # Raises and error since the renamed key conflicts with an existing key. - with pytest.raises(ValueError, match="Key 'a/c' already exists in output"): - _transforms.transform_dict({"a/b": "a/c"}, input) - - # Full match is required and so nothing will be removed. 
- input = {"a": {"b": 1, "c": 2}} - output = _transforms.transform_dict({"a": None}, input) - assert output == input - - # The regex matches the entire key and so the entire input will be removed. - input = {"a": {"b": 1, "c": 2}} - output = _transforms.transform_dict({"a.+": None}, input) - assert output == {} - - # Replace keys using backreferences. All leaves named 'c' are replaced with 'd'. - input = {"a": {"b": 1, "c": 1}, "b": {"c": 2}} - output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, input) - assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}} - - -def test_extract_prompt_from_task(): - transform = _transforms.PromptFromLeRobotTask({1: "Hello, world!"}) - - data = transform({"task_index": 1}) - assert data["prompt"] == "Hello, world!" - - with pytest.raises(ValueError, match="task_index=2 not found in task mapping"): - transform({"task_index": 2}) diff --git a/capvector-pi05/src/vggt/__init__.py b/capvector-pi05/src/vggt/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/src/vggt/dependency/__init__.py b/capvector-pi05/src/vggt/dependency/__init__.py deleted file mode 100644 index 2dad12e149190cb0f746c1fbdf306614e6302714..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .track_modules.track_refine import refine_track -from .track_modules.blocks import BasicEncoder, ShallowEncoder -from .track_modules.base_track_predictor import BaseTrackerPredictor diff --git a/capvector-pi05/src/vggt/dependency/distortion.py b/capvector-pi05/src/vggt/dependency/distortion.py deleted file mode 100644 index 3dad24807a04ddaf61917d4cb3aaf086a7c4095a..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/distortion.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import numpy as np -from typing import Union - -ArrayLike = Union[np.ndarray, torch.Tensor] - - -def _is_numpy(x: ArrayLike) -> bool: - return isinstance(x, np.ndarray) - - -def _is_torch(x: ArrayLike) -> bool: - return isinstance(x, torch.Tensor) - - -def _ensure_torch(x: ArrayLike) -> torch.Tensor: - """Convert input to torch tensor if it's not already one.""" - if _is_numpy(x): - return torch.from_numpy(x) - elif _is_torch(x): - return x - else: - return torch.tensor(x) - - -def single_undistortion(params, tracks_normalized): - """ - Apply undistortion to the normalized tracks using the given distortion parameters once. - - Args: - params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. - tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. - - Returns: - torch.Tensor: Undistorted normalized tracks tensor. - """ - params = _ensure_torch(params) - tracks_normalized = _ensure_torch(tracks_normalized) - - u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() - u_undist, v_undist = apply_distortion(params, u, v) - return torch.stack([u_undist, v_undist], dim=-1) - - -def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6): - """ - Iteratively undistort the normalized tracks using the given distortion parameters. 
- - Args: - params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. - tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. - max_iterations (int): Maximum number of iterations for the undistortion process. - max_step_norm (float): Maximum step norm for convergence. - rel_step_size (float): Relative step size for numerical differentiation. - - Returns: - torch.Tensor: Undistorted normalized tracks tensor. - """ - params = _ensure_torch(params) - tracks_normalized = _ensure_torch(tracks_normalized) - - B, N, _ = tracks_normalized.shape - u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() - original_u, original_v = u.clone(), v.clone() - - eps = torch.finfo(u.dtype).eps - for idx in range(max_iterations): - u_undist, v_undist = apply_distortion(params, u, v) - dx = original_u - u_undist - dy = original_v - v_undist - - step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) - step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) - - J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u) - J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v) - J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u) - J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v) - - J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2) - - delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) - - u += delta[..., 0] - v += delta[..., 1] - - if torch.max((delta**2).sum(dim=-1)) < max_step_norm: - break - - return torch.stack([u, v], dim=-1) - - -def apply_distortion(extra_params, u, v): - """ - Applies radial or OpenCV distortion to the given 2D points. - - Args: - extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4. - u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks. - v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks. - - Returns: - points2D (torch.Tensor): Distorted 2D points of shape BxNx2. 
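`iterative_undistortion` inverts `apply_distortion` with a Newton-style iteration, so distorting a set of points and then undistorting them should recover the originals to within the iteration tolerance. A small round-trip check, assuming both functions from this file are importable under the repository's `src/` layout; the coefficients and point counts are arbitrary:

```python
import torch

from vggt.dependency.distortion import apply_distortion, iterative_undistortion  # assumes src/ is on the path

# One radial coefficient per camera (SIMPLE_RADIAL-style), two cameras, five tracks each.
params = torch.tensor([[0.05], [-0.02]], dtype=torch.float64)
pts = torch.rand(2, 5, 2, dtype=torch.float64) * 0.5

u_dist, v_dist = apply_distortion(params, pts[..., 0], pts[..., 1])
distorted = torch.stack([u_dist, v_dist], dim=-1)

recovered = iterative_undistortion(params, distorted)
assert torch.allclose(recovered, pts, atol=1e-4)  # accurate up to the step-norm stopping criterion
```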
- """ - extra_params = _ensure_torch(extra_params) - u = _ensure_torch(u) - v = _ensure_torch(v) - - num_params = extra_params.shape[1] - - if num_params == 1: - # Simple radial distortion - k = extra_params[:, 0] - u2 = u * u - v2 = v * v - r2 = u2 + v2 - radial = k[:, None] * r2 - du = u * radial - dv = v * radial - - elif num_params == 2: - # RadialCameraModel distortion - k1, k2 = extra_params[:, 0], extra_params[:, 1] - u2 = u * u - v2 = v * v - r2 = u2 + v2 - radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 - du = u * radial - dv = v * radial - - elif num_params == 4: - # OpenCVCameraModel distortion - k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3]) - u2 = u * u - v2 = v * v - uv = u * v - r2 = u2 + v2 - radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 - du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2) - dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2) - else: - raise ValueError("Unsupported number of distortion parameters") - - u = u.clone() + du - v = v.clone() + dv - - return u, v - - -if __name__ == "__main__": - import random - import pycolmap - - max_diff = 0 - for i in range(1000): - # Define distortion parameters (assuming 1 parameter for simplicity) - B = random.randint(1, 500) - track_num = random.randint(100, 1000) - params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters - tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # Batch size 1, 5 points - - # Undistort the tracks - undistorted_tracks = iterative_undistortion(params, tracks_normalized) - - for b in range(B): - pycolmap_intri = np.array([1, 0, 0, params[b].item()]) - pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0) - - undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy()) - diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median() - max_diff = max(max_diff, diff) - print(f"diff: {diff}, max_diff: {max_diff}") - - import pdb - - pdb.set_trace() diff --git a/capvector-pi05/src/vggt/dependency/np_to_pycolmap.py b/capvector-pi05/src/vggt/dependency/np_to_pycolmap.py deleted file mode 100644 index a49c1fb856a69f329dbe4be3ea7627e4e5676f53..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/np_to_pycolmap.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import pycolmap -from .projection import project_3D_points_np - - -def batch_np_matrix_to_pycolmap( - points3d, - extrinsics, - intrinsics, - tracks, - image_size, - masks=None, - max_reproj_error=None, - max_points3D_val=3000, - shared_camera=False, - camera_type="SIMPLE_PINHOLE", - extra_params=None, - min_inlier_per_frame=64, - points_rgb=None, -): - """ - Convert Batched NumPy Arrays to PyCOLMAP - - Check https://github.com/colmap/pycolmap for more details about its format - - NOTE that colmap expects images/cameras/points3D to be 1-indexed - so there is a +1 offset between colmap index and batch index - - - NOTE: different from VGGSfM, this function: - 1. Use np instead of torch - 2. 
Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) - """ - # points3d: Px3 - # extrinsics: Nx3x4 - # intrinsics: Nx3x3 - # tracks: NxPx2 - # masks: NxP - # image_size: 2, assume all the frames have been padded to the same size - # where N is the number of frames and P is the number of tracks - - N, P, _ = tracks.shape - assert len(extrinsics) == N - assert len(intrinsics) == N - assert len(points3d) == P - assert image_size.shape[0] == 2 - - reproj_mask = None - - if max_reproj_error is not None: - projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics) - projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) - projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 - reproj_mask = projected_diff < max_reproj_error - - if masks is not None and reproj_mask is not None: - masks = np.logical_and(masks, reproj_mask) - elif masks is not None: - masks = masks - else: - masks = reproj_mask - - assert masks is not None - - if masks.sum(1).min() < min_inlier_per_frame: - print(f"Not enough inliers per frame, skip BA.") - return None, None - - # Reconstruction object, following the format of PyCOLMAP/COLMAP - reconstruction = pycolmap.Reconstruction() - - inlier_num = masks.sum(0) - valid_mask = inlier_num >= 2 # a track is invalid if without two inliers - valid_idx = np.nonzero(valid_mask)[0] - - # Only add 3D points that have sufficient 2D points - for vidx in valid_idx: - # Use RGB colors if provided, otherwise use zeros - rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) - reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) - - num_points3D = len(valid_idx) - camera = None - # frame idx - for fidx in range(N): - # set camera - if camera is None or (not shared_camera): - pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params) - - camera = pycolmap.Camera( - model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 - ) - - # add camera - reconstruction.add_camera(camera) - - # set image - cam_from_world = pycolmap.Rigid3d( - pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] - ) # Rot and Trans - - image = pycolmap.Image( - id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world - ) - - points2D_list = [] - - point2D_idx = 0 - - # NOTE point3D_id start by 1 - for point3D_id in range(1, num_points3D + 1): - original_track_idx = valid_idx[point3D_id - 1] - - if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): - if masks[fidx][original_track_idx]: - # It seems we don't need +0.5 for BA - point2D_xy = tracks[fidx][original_track_idx] - # Please note when adding the Point2D object - # It not only requires the 2D xy location, but also the id to 3D point - points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) - - # add element - track = reconstruction.points3D[point3D_id].track - track.add_element(fidx + 1, point2D_idx) - point2D_idx += 1 - - assert point2D_idx == len(points2D_list) - - try: - image.points2D = pycolmap.ListPoint2D(points2D_list) - image.registered = True - except: - print(f"frame {fidx + 1} is out of BA") - image.registered = False - - # add image - reconstruction.add_image(image) - - return reconstruction, valid_mask - - -def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"): - """ - Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. 
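A minimal usage sketch for the converter above, with random data whose shapes follow the comments in the function body (N frames, P tracks). All values are arbitrary and only meant to show the expected layout and the 1-based indexing of the resulting reconstruction:

```python
import numpy as np

from vggt.dependency.np_to_pycolmap import (  # assumes src/ is on the path
    batch_np_matrix_to_pycolmap,
    pycolmap_to_batch_np_matrix,
)

N, P = 4, 200
points3d = np.random.rand(P, 3)
extrinsics = np.tile(np.eye(3, 4), (N, 1, 1))     # [R|t] per frame
intrinsics = np.tile(np.array([[500.0, 0, 320], [0, 500.0, 240], [0, 0, 1]]), (N, 1, 1))
tracks = np.random.rand(N, P, 2) * np.array([640, 480])
masks = np.ones((N, P), dtype=bool)
image_size = np.array([640, 480])

reconstruction, valid_mask = batch_np_matrix_to_pycolmap(
    points3d, extrinsics, intrinsics, tracks, image_size, masks=masks
)
# reconstruction.images and reconstruction.cameras are keyed 1..N; 3D points are keyed 1..valid_mask.sum().
points3d_back, extrinsics_back, intrinsics_back, _ = pycolmap_to_batch_np_matrix(reconstruction)
```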
- - Args: - reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. - device (str): Ignored in NumPy version (kept for API compatibility). - camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). - - Returns: - tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. - """ - - num_images = len(reconstruction.images) - max_points3D_id = max(reconstruction.point3D_ids()) - points3D = np.zeros((max_points3D_id, 3)) - - for point3D_id in reconstruction.points3D: - points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz - - extrinsics = [] - intrinsics = [] - - extra_params = [] if camera_type == "SIMPLE_RADIAL" else None - - for i in range(num_images): - # Extract and append extrinsics - pyimg = reconstruction.images[i + 1] - pycam = reconstruction.cameras[pyimg.camera_id] - matrix = pyimg.cam_from_world.matrix() - extrinsics.append(matrix) - - # Extract and append intrinsics - calibration_matrix = pycam.calibration_matrix() - intrinsics.append(calibration_matrix) - - if camera_type == "SIMPLE_RADIAL": - extra_params.append(pycam.params[-1]) - - # Convert lists to NumPy arrays instead of torch tensors - extrinsics = np.stack(extrinsics) - intrinsics = np.stack(intrinsics) - - if camera_type == "SIMPLE_RADIAL": - extra_params = np.stack(extra_params) - extra_params = extra_params[:, None] - - return points3D, extrinsics, intrinsics, extra_params - - -######################################################## - - -def batch_np_matrix_to_pycolmap_wo_track( - points3d, - points_xyf, - points_rgb, - extrinsics, - intrinsics, - image_size, - shared_camera=False, - camera_type="SIMPLE_PINHOLE", -): - """ - Convert Batched NumPy Arrays to PyCOLMAP - - Different from batch_np_matrix_to_pycolmap, this function does not use tracks. - - It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. - - Do NOT use this for BA. 
- """ - # points3d: Px3 - # points_xyf: Px3, with x, y coordinates and frame indices - # points_rgb: Px3, rgb colors - # extrinsics: Nx3x4 - # intrinsics: Nx3x3 - # image_size: 2, assume all the frames have been padded to the same size - # where N is the number of frames and P is the number of tracks - - N = len(extrinsics) - P = len(points3d) - - # Reconstruction object, following the format of PyCOLMAP/COLMAP - reconstruction = pycolmap.Reconstruction() - - for vidx in range(P): - reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) - - camera = None - # frame idx - for fidx in range(N): - # set camera - if camera is None or (not shared_camera): - pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) - - camera = pycolmap.Camera( - model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 - ) - - # add camera - reconstruction.add_camera(camera) - - # set image - cam_from_world = pycolmap.Rigid3d( - pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] - ) # Rot and Trans - - image = pycolmap.Image( - id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world - ) - - points2D_list = [] - - point2D_idx = 0 - - points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx - points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] - - for point3D_batch_idx in points_belong_to_fidx: - point3D_id = point3D_batch_idx + 1 - point2D_xyf = points_xyf[point3D_batch_idx] - point2D_xy = point2D_xyf[:2] - points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) - - # add element - track = reconstruction.points3D[point3D_id].track - track.add_element(fidx + 1, point2D_idx) - point2D_idx += 1 - - assert point2D_idx == len(points2D_list) - - try: - image.points2D = pycolmap.ListPoint2D(points2D_list) - image.registered = True - except: - print(f"frame {fidx + 1} does not have any points") - image.registered = False - - # add image - reconstruction.add_image(image) - - return reconstruction - - -def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): - """ - Helper function to get camera parameters based on camera type. - - Args: - fidx: Frame index - intrinsics: Camera intrinsic parameters - camera_type: Type of camera model - extra_params: Additional parameters for certain camera types - - Returns: - pycolmap_intri: NumPy array of camera parameters - """ - if camera_type == "PINHOLE": - pycolmap_intri = np.array( - [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] - ) - elif camera_type == "SIMPLE_PINHOLE": - focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 - pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]) - elif camera_type == "SIMPLE_RADIAL": - raise NotImplementedError("SIMPLE_RADIAL is not supported yet") - focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 - pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]]) - else: - raise ValueError(f"Camera type {camera_type} is not supported yet") - - return pycolmap_intri diff --git a/capvector-pi05/src/vggt/dependency/projection.py b/capvector-pi05/src/vggt/dependency/projection.py deleted file mode 100644 index 38fd175fe6fce096ef2bfb6b0996085c5ae44fee..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/projection.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import numpy as np -from .distortion import apply_distortion - - -def img_from_cam_np( - intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0 -) -> np.ndarray: - """ - Apply intrinsics (and optional radial distortion) to camera-space points. - - Args - ---- - intrinsics : (B,3,3) camera matrix K. - points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ. - extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. - default : value used for np.nan replacement. - - Returns - ------- - points2D : (B,N,2) pixel coordinates. - """ - # 1. perspective divide ─────────────────────────────────────── - z = points_cam[:, 2:3, :] # (B,1,N) - points_cam_norm = points_cam / z # (B,3,N) - uv = points_cam_norm[:, :2, :] # (B,2,N) - - # 2. optional distortion ────────────────────────────────────── - if extra_params is not None: - uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) - uv = np.stack([uu, vv], axis=1) # (B,2,N) - - # 3. homogeneous coords then K multiplication ───────────────── - ones = np.ones_like(uv[:, :1, :]) # (B,1,N) - points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) - - # batched mat-mul: K · [u v 1]ᵀ - points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) - points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) - - return points2D.transpose(0, 2, 1) # (B,N,2) - - -def project_3D_points_np( - points3D: np.ndarray, - extrinsics: np.ndarray, - intrinsics: np.ndarray | None = None, - extra_params: np.ndarray | None = None, - *, - default: float = 0.0, - only_points_cam: bool = False, -): - """ - NumPy clone of ``project_3D_points``. - - Parameters - ---------- - points3D : (N,3) world-space points. - extrinsics : (B,3,4) [R|t] matrix for each of B cameras. - intrinsics : (B,3,3) K matrix (optional if you only need cam-space). - extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. - default : value used to replace NaNs. - only_points_cam : if True, skip the projection and return points_cam with points2D as None. - - Returns - ------- - (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, - and points_cam is (B,3,N) camera-space coordinates. - """ - # ----- 0. prep sizes ----------------------------------------------------- - N = points3D.shape[0] # #points - B = extrinsics.shape[0] # #cameras - - # ----- 1. world → homogeneous ------------------------------------------- - w_h = np.ones((N, 1), dtype=points3D.dtype) - points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) - - # broadcast to every camera (no actual copying with np.broadcast_to) ------ - points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) - - # ----- 2. apply extrinsics (camera frame) ------------------------------ - # X_cam = E · X_hom - # einsum: E_(b i j) · X_(b n j) → (b n i) - points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) - points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) - - if only_points_cam: - return None, points_cam - - # ----- 3. 
intrinsics + distortion --------------------------------------- - if intrinsics is None: - raise ValueError("`intrinsics` must be provided unless only_points_cam=True") - - points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default) - - return points2D, points_cam - - -def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False): - """ - Transforms 3D points to 2D using extrinsic and intrinsic parameters. - Args: - points3D (torch.Tensor): 3D points of shape Px3. - extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. - intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. - extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. - default (float): Default value to replace NaNs. - only_points_cam (bool): If True, skip the projection and return points2D as None. - - Returns: - tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, - and points_cam is of shape Bx3xN. - """ - with torch.cuda.amp.autocast(dtype=torch.double): - N = points3D.shape[0] # Number of points - B = extrinsics.shape[0] # Batch size, i.e., number of cameras - points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4 - # Reshape for batch processing - points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4 - - # Step 1: Apply extrinsic parameters - # Transform 3D points to camera coordinate system for all cameras - points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) - - if only_points_cam: - return None, points_cam - - # Step 2: Apply intrinsic parameters and (optional) distortion - points2D = img_from_cam(intrinsics, points_cam, extra_params, default) - - return points2D, points_cam - - -def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): - """ - Applies intrinsic parameters and optional distortion to the given 3D points. - - Args: - intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. - points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. - extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. - default (float, optional): Default value to replace NaNs in the output. - - Returns: - points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
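Both projection paths implement the same pinhole model: transform to camera coordinates with [R|t], divide by depth, optionally distort, then multiply by K. A hand-checked example with one camera and one point, using the NumPy variant from this file; the numbers are chosen for easy arithmetic:

```python
import numpy as np

from vggt.dependency.projection import project_3D_points_np  # assumes src/ is on the path

K = np.array([[100.0, 0.0, 50.0],
              [0.0, 100.0, 40.0],
              [0.0, 0.0, 1.0]])
extrinsic = np.hstack([np.eye(3), np.array([[0.0], [0.0], [1.0]])])  # R = I, t = (0, 0, 1)

point = np.array([[0.2, 0.1, 1.0]])  # world point; camera-space depth becomes 1 + 1 = 2

points2D, points_cam = project_3D_points_np(point, extrinsic[None], K[None])
# camera space: (0.2, 0.1, 2.0); normalized: (0.1, 0.05); pixels: (100*0.1 + 50, 100*0.05 + 40) = (60, 45)
assert np.allclose(points2D[0, 0], [60.0, 45.0])
```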
- """ - - # Normalize by the third coordinate (homogeneous division) - points_cam = points_cam / points_cam[:, 2:3, :] - # Extract uv - uv = points_cam[:, :2, :] - - # Apply distortion if extra_params are provided - if extra_params is not None: - uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) - uv = torch.stack([uu, vv], dim=1) - - # Prepare points_cam for batch matrix multiplication - points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN - # Apply intrinsic parameters using batch matrix multiplication - points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN - - # Extract x and y coordinates - points2D = points2D_homo[:, :2, :] # Bx2xN - - # Replace NaNs with default value - points2D = torch.nan_to_num(points2D, nan=default) - - return points2D.transpose(1, 2) # BxNx2 - - -if __name__ == "__main__": - # Set up example input - B, N = 24, 10240 - - for _ in range(100): - points3D = np.random.rand(N, 3).astype(np.float64) - extrinsics = np.random.rand(B, 3, 4).astype(np.float64) - intrinsics = np.random.rand(B, 3, 3).astype(np.float64) - - # Convert to torch tensors - points3D_torch = torch.tensor(points3D) - extrinsics_torch = torch.tensor(extrinsics) - intrinsics_torch = torch.tensor(intrinsics) - - # Run NumPy implementation - points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics) - - # Run torch implementation - points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch) - - # Convert torch output to numpy - points2D_torch_np = points2D_torch.detach().numpy() - points_cam_torch_np = points_cam_torch.detach().numpy() - - # Compute difference - diff = np.abs(points2D_np - points2D_torch_np) - print("Difference between NumPy and PyTorch implementations:") - print(diff) - - # Check max error - max_diff = np.max(diff) - print(f"Maximum difference: {max_diff}") - - if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): - print("Implementations match closely.") - else: - print("Significant differences detected.") - - if points_cam_np is not None: - points_cam_diff = np.abs(points_cam_np - points_cam_torch_np) - print("Difference between NumPy and PyTorch camera-space coordinates:") - print(points_cam_diff) - - # Check max error - max_cam_diff = np.max(points_cam_diff) - print(f"Maximum camera-space coordinate difference: {max_cam_diff}") - - if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): - print("Camera-space coordinates match closely.") - else: - print("Significant differences detected in camera-space coordinates.") diff --git a/capvector-pi05/src/vggt/dependency/track_modules/__init__.py b/capvector-pi05/src/vggt/dependency/track_modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py b/capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py deleted file mode 100644 index 27aa7092691f4526f4a93ca76170783b1e71c335..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn -from einops import rearrange, repeat - -from .blocks import EfficientUpdateFormer, CorrBlock -from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed - - -class BaseTrackerPredictor(nn.Module): - def __init__( - self, - stride=4, - corr_levels=5, - corr_radius=4, - latent_dim=128, - hidden_size=384, - use_spaceatt=True, - depth=6, - fine=False, - ): - super(BaseTrackerPredictor, self).__init__() - """ - The base template to create a track predictor - - Modified from https://github.com/facebookresearch/co-tracker/ - """ - - self.stride = stride - self.latent_dim = latent_dim - self.corr_levels = corr_levels - self.corr_radius = corr_radius - self.hidden_size = hidden_size - self.fine = fine - - self.flows_emb_dim = latent_dim // 2 - self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 - - if self.fine: - # TODO this is the old dummy code, will remove this when we train next model - self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 - else: - self.transformer_dim += (4 - self.transformer_dim % 4) % 4 - - space_depth = depth if use_spaceatt else 0 - time_depth = depth - - self.updateformer = EfficientUpdateFormer( - space_depth=space_depth, - time_depth=time_depth, - input_dim=self.transformer_dim, - hidden_size=self.hidden_size, - output_dim=self.latent_dim + 2, - mlp_ratio=4.0, - add_space_attn=use_spaceatt, - ) - - self.norm = nn.GroupNorm(1, self.latent_dim) - - # A linear layer to update track feats at each iteration - self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) - - if not self.fine: - self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) - - def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1): - """ - query_points: B x N x 2, the number of batches, tracks, and xy - fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
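The width of the update transformer's input is the sum of the correlation features, the flow embedding, and the track features, rounded up to a multiple of four. Working the default hyperparameters from `__init__` above through that formula (illustrative arithmetic only):

```python
corr_levels, corr_radius, latent_dim = 5, 4, 128       # defaults in __init__ above

corr_dim = corr_levels * (2 * corr_radius + 1) ** 2    # 5 * 81 = 405 correlation values per track and frame
flow_plus_feat = 2 * latent_dim                        # latent_dim for the flow embedding, latent_dim for track feats
transformer_dim = corr_dim + flow_plus_feat            # 405 + 256 = 661
transformer_dim += (4 - transformer_dim % 4) % 4       # coarse (non-fine) path pads to a multiple of 4 -> 664
print(transformer_dim)                                 # 664
```

The forward pass then zero-pads the actual concatenation (flow embedding + raw flow + correlations + track features) up to this width before feeding it to the transformer.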
- note HH and WW is the size of feature maps instead of original images - """ - B, N, D = query_points.shape - B, S, C, HH, WW = fmaps.shape - - assert D == 2 - - # Scale the input query_points because we may downsample the images - # by down_ratio or self.stride - # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map - # its query_points should be query_points/4 - if down_ratio > 1: - query_points = query_points / float(down_ratio) - query_points = query_points / float(self.stride) - - # Init with coords as the query points - # It means the search will start from the position of query points at the reference frames - coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) - - # Sample/extract the features of the query points in the query frame - query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) - - # init track feats by query feats - track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C - # back up the init coords - coords_backup = coords.clone() - - # Construct the correlation block - - fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) - - coord_preds = [] - - # Iterative Refinement - for itr in range(iters): - # Detach the gradients from the last iteration - # (in my experience, not very important for performance) - coords = coords.detach() - - # Compute the correlation (check the implementation of CorrBlock) - - fcorr_fn.corr(track_feats) - fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim - - corrdim = fcorrs.shape[3] - - fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) - - # Movement of current coords relative to query points - flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) - - flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) - - # (In my trials, it is also okay to just add the flows_emb instead of concat) - flows_emb = torch.cat([flows_emb, flows], dim=-1) - - track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) - - # Concatenate them as the input for the transformers - transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) - - if transformer_input.shape[2] < self.transformer_dim: - # pad the features to match the dimension - pad_dim = self.transformer_dim - transformer_input.shape[2] - pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) - transformer_input = torch.cat([transformer_input, pad], dim=2) - - # 2D positional embed - # TODO: this can be much simplified - pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) - sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) - sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) - - x = transformer_input + sampled_pos_emb - - # B, N, S, C - x = rearrange(x, "(b n) s d -> b n s d", b=B) - - # Compute the delta coordinates and delta track features - delta = self.updateformer(x) - # BN, S, C - delta = rearrange(delta, " b n s d -> (b n) s d", b=B) - delta_coords_ = delta[:, :, :2] - delta_feats_ = delta[:, :, 2:] - - track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) - delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) - - # Update the track features - track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ - track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC - - # B x S x N x 2 - coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 
2, 1, 3) - - # Force coord0 as query - # because we assume the query points should not be changed - coords[:, 0] = coords_backup[:, 0] - - # The predicted tracks are in the original image scale - if down_ratio > 1: - coord_preds.append(coords * self.stride * down_ratio) - else: - coord_preds.append(coords * self.stride) - - # B, S, N - if not self.fine: - vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) - vis_e = torch.sigmoid(vis_e) - else: - vis_e = None - - if return_feat: - return coord_preds, vis_e, track_feats, query_track_feat - else: - return coord_preds, vis_e diff --git a/capvector-pi05/src/vggt/dependency/track_modules/blocks.py b/capvector-pi05/src/vggt/dependency/track_modules/blocks.py deleted file mode 100644 index 513f96836644ff27e714cba517510d2dd7e702df..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/track_modules/blocks.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -# Modified from https://github.com/facebookresearch/co-tracker/ - - -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from typing import Callable -import collections -from torch import Tensor -from itertools import repeat - -from .utils import bilinear_sampler - -from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock - - -class BasicEncoder(nn.Module): - def __init__(self, input_dim=3, output_dim=128, stride=4): - super(BasicEncoder, self).__init__() - - self.stride = stride - self.norm_fn = "instance" - self.in_planes = output_dim // 2 - - self.norm1 = nn.InstanceNorm2d(self.in_planes) - self.norm2 = nn.InstanceNorm2d(output_dim * 2) - - self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros") - self.relu1 = nn.ReLU(inplace=True) - self.layer1 = self._make_layer(output_dim // 2, stride=1) - self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) - self.layer3 = self._make_layer(output_dim, stride=2) - self.layer4 = self._make_layer(output_dim, stride=2) - - self.conv2 = nn.Conv2d( - output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros" - ) - self.relu2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.InstanceNorm2d)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - def forward(self, x): - _, _, H, W = x.shape - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - a = self.layer1(x) - b = self.layer2(a) - c = self.layer3(b) - d = self.layer4(c) - - a = _bilinear_intepolate(a, self.stride, H, W) - b = _bilinear_intepolate(b, self.stride, H, W) - c = _bilinear_intepolate(c, self.stride, H, W) - d = _bilinear_intepolate(d, self.stride, H, W) - - x = self.conv2(torch.cat([a, b, c, d], dim=1)) - x = self.norm2(x) - x = self.relu2(x) - x = 
self.conv3(x) - return x - - -class ShallowEncoder(nn.Module): - def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"): - super(ShallowEncoder, self).__init__() - self.stride = stride - self.norm_fn = norm_fn - self.in_planes = output_dim - - if self.norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) - self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) - elif self.norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(self.in_planes) - self.norm2 = nn.BatchNorm2d(output_dim * 2) - elif self.norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(self.in_planes) - self.norm2 = nn.InstanceNorm2d(output_dim * 2) - elif self.norm_fn == "none": - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros") - self.relu1 = nn.ReLU(inplace=True) - - self.layer1 = self._make_layer(output_dim, stride=2) - - self.layer2 = self._make_layer(output_dim, stride=2) - self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - self.in_planes = dim - - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - return layer1 - - def forward(self, x): - _, _, H, W = x.shape - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - tmp = self.layer1(x) - x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) - tmp = self.layer2(tmp) - x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) - tmp = None - x = self.conv2(x) + x - - x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True) - - return x - - -def _bilinear_intepolate(x, stride, H, W): - return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True) - - -class EfficientUpdateFormer(nn.Module): - """ - Transformer model that updates track estimates. 
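`EfficientUpdateFormer`, defined just below, alternates attention over time (within each track) with attention over tracks (within each frame), routing the spatial pass through a set of learned virtual tracks. A shape sketch with its default hyperparameters, assuming the class and its `modules` dependencies are importable under the repository's `src/` layout; the batch sizes are arbitrary:

```python
import torch

from vggt.dependency.track_modules.blocks import EfficientUpdateFormer  # assumes src/ is on the path

# Hypothetical sizes: batch of 2 sequences, 8 query tracks, 5 frames, default input width 320.
model = EfficientUpdateFormer()   # defaults: hidden_size=384, output_dim=130, 64 virtual tracks
x = torch.randn(2, 8, 5, 320)     # (B, N tracks, T frames, input_dim)

flow = model(x)
print(flow.shape)                 # torch.Size([2, 8, 5, 130]): 2 coordinate deltas + 128 feature channels per step
```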
- """ - - def __init__( - self, - space_depth=6, - time_depth=6, - input_dim=320, - hidden_size=384, - num_heads=8, - output_dim=130, - mlp_ratio=4.0, - add_space_attn=True, - num_virtual_tracks=64, - ): - super().__init__() - - self.out_channels = 2 - self.num_heads = num_heads - self.hidden_size = hidden_size - self.add_space_attn = add_space_attn - self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) - self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) - self.num_virtual_tracks = num_virtual_tracks - - if self.add_space_attn: - self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) - else: - self.virual_tracks = None - - self.time_blocks = nn.ModuleList( - [ - AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) - for _ in range(time_depth) - ] - ) - - if add_space_attn: - self.space_virtual_blocks = nn.ModuleList( - [ - AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) - for _ in range(space_depth) - ] - ) - self.space_point2virtual_blocks = nn.ModuleList( - [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] - ) - self.space_virtual2point_blocks = nn.ModuleList( - [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] - ) - assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) - self.initialize_weights() - - def initialize_weights(self): - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - - def forward(self, input_tensor, mask=None): - tokens = self.input_transform(input_tensor) - - init_tokens = tokens - - B, _, T, _ = tokens.shape - - if self.add_space_attn: - virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) - tokens = torch.cat([tokens, virtual_tokens], dim=1) - - _, N, _, _ = tokens.shape - - j = 0 - for i in range(len(self.time_blocks)): - time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C - time_tokens = self.time_blocks[i](time_tokens) - - tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C - if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): - space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C - point_tokens = space_tokens[:, : N - self.num_virtual_tracks] - virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] - - virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) - virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) - point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) - space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) - tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C - j += 1 - - if self.add_space_attn: - tokens = tokens[:, : N - self.num_virtual_tracks] - - tokens = tokens + init_tokens - - flow = self.flow_head(tokens) - return flow - - -class CorrBlock: - def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, 
padding_mode="zeros"): - B, S, C, H, W = fmaps.shape - self.S, self.C, self.H, self.W = S, C, H, W - self.padding_mode = padding_mode - self.num_levels = num_levels - self.radius = radius - self.fmaps_pyramid = [] - self.multiple_track_feats = multiple_track_feats - - self.fmaps_pyramid.append(fmaps) - for i in range(self.num_levels - 1): - fmaps_ = fmaps.reshape(B * S, C, H, W) - fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) - _, _, H, W = fmaps_.shape - fmaps = fmaps_.reshape(B, S, C, H, W) - self.fmaps_pyramid.append(fmaps) - - def sample(self, coords): - r = self.radius - B, S, N, D = coords.shape - assert D == 2 - - H, W = self.H, self.W - out_pyramid = [] - for i in range(self.num_levels): - corrs = self.corrs_pyramid[i] # B, S, N, H, W - *_, H, W = corrs.shape - - dx = torch.linspace(-r, r, 2 * r + 1) - dy = torch.linspace(-r, r, 2 * r + 1) - delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) - - centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i - delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) - coords_lvl = centroid_lvl + delta_lvl - - corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode) - corrs = corrs.view(B, S, N, -1) - - out_pyramid.append(corrs) - - out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 - return out - - def corr(self, targets): - B, S, N, C = targets.shape - if self.multiple_track_feats: - targets_split = targets.split(C // self.num_levels, dim=-1) - B, S, N, C = targets_split[0].shape - - assert C == self.C - assert S == self.S - - fmap1 = targets - - self.corrs_pyramid = [] - for i, fmaps in enumerate(self.fmaps_pyramid): - *_, H, W = fmaps.shape - fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) - if self.multiple_track_feats: - fmap1 = targets_split[i] - corrs = torch.matmul(fmap1, fmap2s) - corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W - corrs = corrs / torch.sqrt(torch.tensor(C).float()) - self.corrs_pyramid.append(corrs) diff --git a/capvector-pi05/src/vggt/dependency/track_modules/modules.py b/capvector-pi05/src/vggt/dependency/track_modules/modules.py deleted file mode 100644 index e1a5cdb57239a9e40f8cf2e208622c06f6492004..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/track_modules/modules.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- - -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from typing import Callable -import collections -from torch import Tensor -from itertools import repeat - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - return tuple(x) - return tuple(repeat(x, n)) - - return parse - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -to_2tuple = _ntuple(2) - - -class ResidualBlock(nn.Module): - """ - ResidualBlock: construct a block of two conv layers with residual connections - """ - - def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): - super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" - ) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == "none": - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not stride == 1: - self.norm3 = nn.Sequential() - else: - raise NotImplementedError - - if stride == 1: - self.downsample = None - else: - self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x + y) - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=None, - bias=True, - drop=0.0, - use_conv=False, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - bias = to_2tuple(bias) - drop_probs = to_2tuple(drop) - linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - - self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) - self.act = act_layer() - self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) - self.drop2 = nn.Dropout(drop_probs[1]) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop1(x) - x = self.fc2(x) - x = self.drop2(x) - return x - - -class AttnBlock(nn.Module): - def __init__( - self, - hidden_size, - num_heads, - attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, - mlp_ratio=4.0, - **block_kwargs, - ): - """ - Self attention block - """ - super().__init__() - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - - 
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) - - def forward(self, x, mask=None): - # Prepare the mask for PyTorch's attention (it expects a different format) - # attn_mask = mask if mask is not None else None - # Normalize before attention - x = self.norm1(x) - - # PyTorch's MultiheadAttention returns attn_output, attn_output_weights - # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) - - attn_output, _ = self.attn(x, x, x) - - # Add & Norm - x = x + attn_output - x = x + self.mlp(self.norm2(x)) - return x - - -class CrossAttnBlock(nn.Module): - def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): - """ - Cross attention block - """ - super().__init__() - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.norm_context = nn.LayerNorm(hidden_size) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - - self.cross_attn = nn.MultiheadAttention( - embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs - ) - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) - - def forward(self, x, context, mask=None): - # Normalize inputs - x = self.norm1(x) - context = self.norm_context(context) - - # Apply cross attention - # Note: nn.MultiheadAttention returns attn_output, attn_output_weights - attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) - - # Add & Norm - x = x + attn_output - x = x + self.mlp(self.norm2(x)) - return x diff --git a/capvector-pi05/src/vggt/dependency/track_modules/track_refine.py b/capvector-pi05/src/vggt/dependency/track_modules/track_refine.py deleted file mode 100644 index 461572c2096f2fb69de45cef9ba401465ebd084f..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/track_modules/track_refine.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from torch import nn, einsum -from einops import rearrange, repeat -from einops.layers.torch import Rearrange, Reduce - -from PIL import Image -import os -from typing import Union, Tuple - - -def refine_track( - images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960 -): - """ - Refines the tracking of images using a fine track predictor and a fine feature network. - Check https://arxiv.org/abs/2312.04563 for more details. - - Args: - images (torch.Tensor): The images to be tracked. - fine_fnet (nn.Module): The fine feature network. - fine_tracker (nn.Module): The fine track predictor. - coarse_pred (torch.Tensor): The coarse predictions of tracks. - compute_score (bool, optional): Whether to compute the score. Defaults to False. - pradius (int, optional): The radius of a patch. Defaults to 15. - sradius (int, optional): The search radius. Defaults to 2. - - Returns: - torch.Tensor: The refined tracks. - torch.Tensor, optional: The score. 
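`AttnBlock` and `CrossAttnBlock` above are pre-norm transformer blocks built on `nn.MultiheadAttention` (batch-first); the update former interleaves them over the time axis and the track axis. A small shape-check sketch, assuming the two classes are importable; sizes are illustrative:

```python
# Hypothetical shape check for AttnBlock / CrossAttnBlock (illustrative sizes).
import torch

hidden, heads, T, N = 384, 8, 8, 64
self_block = AttnBlock(hidden, heads, mlp_ratio=4.0)
cross_block = CrossAttnBlock(hidden, hidden, num_heads=heads)

time_tokens = torch.randn(N, T, hidden)                  # (B*N) x T x C, as fed to the time blocks
space_tokens = torch.randn(T, N, hidden)                 # (B*T) x N x C, as fed to the space blocks

time_tokens = self_block(time_tokens)                    # self-attention over time, shape preserved
space_tokens = cross_block(space_tokens, space_tokens)   # cross-attention; context == queries here just for the check
```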
- """ - - # coarse_pred shape: BxSxNx2, - # where B is the batch, S is the video/images length, and N is the number of tracks - # now we are going to extract patches with the center at coarse_pred - # Please note that the last dimension indicates x and y, and hence has a dim number of 2 - B, S, N, _ = coarse_pred.shape - _, _, _, H, W = images.shape - - # Given the raidus of a patch, compute the patch size - psize = pradius * 2 + 1 - - # Note that we assume the first frame is the query frame - # so the 2D locations of the first frame are the query points - query_points = coarse_pred[:, 0] - - # Given 2D positions, we can use grid_sample to extract patches - # but it takes too much memory. - # Instead, we use the floored track xy to sample patches. - - # For example, if the query point xy is (128.16, 252.78), - # and the patch size is (31, 31), - # our goal is to extract the content of a rectangle - # with left top: (113.16, 237.78) - # and right bottom: (143.16, 267.78). - # However, we record the floored left top: (113, 237) - # and the offset (0.16, 0.78) - # Then what we need is just unfolding the images like in CNN, - # picking the content at [(113, 237), (143, 267)]. - # Such operations are highly optimized at pytorch - # (well if you really want to use interpolation, check the function extract_glimpse() below) - - with torch.no_grad(): - content_to_extract = images.reshape(B * S, 3, H, W) - C_in = content_to_extract.shape[1] - - # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html - # for the detailed explanation of unfold() - # Here it runs sliding windows (psize x psize) to build patches - # The shape changes from - # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize - # where Psize is the size of patch - content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) - - # Floor the coarse predictions to get integers and save the fractional/decimal - track_int = coarse_pred.floor().int() - track_frac = coarse_pred - track_int - - # Note the points represent the center of patches - # now we get the location of the top left corner of patches - # because the ouput of pytorch unfold are indexed by top left corner - topleft = track_int - pradius - topleft_BSN = topleft.clone() - - # clamp the values so that we will not go out of indexes - # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
- # You need to seperately clamp x and y if H!=W - topleft = topleft.clamp(0, H - psize) - - # Reshape from BxSxNx2 -> (B*S)xNx2 - topleft = topleft.reshape(B * S, N, 2) - - # Prepare batches for indexing, shape: (B*S)xN - batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) - - # extracted_patches: (B*S) x N x C_in x Psize x Psize - extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] - - if chunk < 0: - # Extract image patches based on top left corners - # Feed patches to fine fent for features - patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) - else: - patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) - - patch_feat_list = [] - for p in torch.split(patches, chunk): - patch_feat_list += [fine_fnet(p)] - patch_feat = torch.cat(patch_feat_list, 0) - - C_out = patch_feat.shape[1] - - # Refine the coarse tracks by fine_tracker - # reshape back to B x S x N x C_out x Psize x Psize - patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) - patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") - - # Prepare for the query points for fine tracker - # They are relative to the patch left top corner, - # instead of the image top left corner now - # patch_query_points: N x 1 x 2 - # only 1 here because for each patch we only have 1 query point - patch_query_points = track_frac[:, 0] + pradius - patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) - - # Feed the PATCH query points and tracks into fine tracker - fine_pred_track_lists, _, _, query_point_feat = fine_tracker( - query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True - ) - - # relative the patch top left - fine_pred_track = fine_pred_track_lists[-1].clone() - - # From (relative to the patch top left) to (relative to the image top left) - for idx in range(len(fine_pred_track_lists)): - fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) - fine_level = fine_level.squeeze(-2) - fine_level = fine_level + topleft_BSN - fine_pred_track_lists[idx] = fine_level - - # relative to the image top left - refined_tracks = fine_pred_track_lists[-1].clone() - refined_tracks[:, 0] = query_points - - score = None - - if compute_score: - score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) - - return refined_tracks, score - - -def refine_track_v0( - images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6 -): - """ - COPIED FROM VGGSfM - - Refines the tracking of images using a fine track predictor and a fine feature network. - Check https://arxiv.org/abs/2312.04563 for more details. - - Args: - images (torch.Tensor): The images to be tracked. - fine_fnet (nn.Module): The fine feature network. - fine_tracker (nn.Module): The fine track predictor. - coarse_pred (torch.Tensor): The coarse predictions of tracks. - compute_score (bool, optional): Whether to compute the score. Defaults to False. - pradius (int, optional): The radius of a patch. Defaults to 15. - sradius (int, optional): The search radius. Defaults to 2. - - Returns: - torch.Tensor: The refined tracks. - torch.Tensor, optional: The score. 
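The comment block above walks through the floor-and-unfold trick: record the integer top-left corner plus the sub-pixel offset, then pick patches out of the unfolded image instead of calling `grid_sample`. A standalone sketch of just that trick (not the deleted function itself), with made-up sizes and one point per frame:

```python
# Standalone sketch of the floor + unfold patch extraction described above (made-up sizes).
import torch

BS, C, H, W = 4, 3, 512, 512                         # (batch*frames), channels, square images
pradius = 15
psize = 2 * pradius + 1                              # 31

images = torch.rand(BS, C, H, W)
coords = torch.tensor([[128.16, 252.78]]).expand(BS, 1, 2)        # one xy point per image

# Sliding windows: BS x C x H' x W' x psize x psize, indexed by the patch's top-left corner.
windows = images.unfold(2, psize, 1).unfold(3, psize, 1)

topleft = (coords.floor().long() - pradius).clamp(0, H - psize)   # assumes H == W, as the note warns
frac = coords - coords.floor()                                     # sub-pixel offset, kept for the fine tracker

batch_idx = torch.arange(BS)[:, None]
patches = windows[batch_idx, :, topleft[..., 1], topleft[..., 0]]  # BS x 1 x C x psize x psize
```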
- """ - - # coarse_pred shape: BxSxNx2, - # where B is the batch, S is the video/images length, and N is the number of tracks - # now we are going to extract patches with the center at coarse_pred - # Please note that the last dimension indicates x and y, and hence has a dim number of 2 - B, S, N, _ = coarse_pred.shape - _, _, _, H, W = images.shape - - # Given the raidus of a patch, compute the patch size - psize = pradius * 2 + 1 - - # Note that we assume the first frame is the query frame - # so the 2D locations of the first frame are the query points - query_points = coarse_pred[:, 0] - - # Given 2D positions, we can use grid_sample to extract patches - # but it takes too much memory. - # Instead, we use the floored track xy to sample patches. - - # For example, if the query point xy is (128.16, 252.78), - # and the patch size is (31, 31), - # our goal is to extract the content of a rectangle - # with left top: (113.16, 237.78) - # and right bottom: (143.16, 267.78). - # However, we record the floored left top: (113, 237) - # and the offset (0.16, 0.78) - # Then what we need is just unfolding the images like in CNN, - # picking the content at [(113, 237), (143, 267)]. - # Such operations are highly optimized at pytorch - # (well if you really want to use interpolation, check the function extract_glimpse() below) - - with torch.no_grad(): - content_to_extract = images.reshape(B * S, 3, H, W) - C_in = content_to_extract.shape[1] - - # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html - # for the detailed explanation of unfold() - # Here it runs sliding windows (psize x psize) to build patches - # The shape changes from - # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize - # where Psize is the size of patch - content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) - - # Floor the coarse predictions to get integers and save the fractional/decimal - track_int = coarse_pred.floor().int() - track_frac = coarse_pred - track_int - - # Note the points represent the center of patches - # now we get the location of the top left corner of patches - # because the ouput of pytorch unfold are indexed by top left corner - topleft = track_int - pradius - topleft_BSN = topleft.clone() - - # clamp the values so that we will not go out of indexes - # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
- # You need to seperately clamp x and y if H!=W - topleft = topleft.clamp(0, H - psize) - - # Reshape from BxSxNx2 -> (B*S)xNx2 - topleft = topleft.reshape(B * S, N, 2) - - # Prepare batches for indexing, shape: (B*S)xN - batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) - - # Extract image patches based on top left corners - # extracted_patches: (B*S) x N x C_in x Psize x Psize - extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] - - # Feed patches to fine fent for features - patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) - - C_out = patch_feat.shape[1] - - # Refine the coarse tracks by fine_tracker - - # reshape back to B x S x N x C_out x Psize x Psize - patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) - patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") - - # Prepare for the query points for fine tracker - # They are relative to the patch left top corner, - # instead of the image top left corner now - # patch_query_points: N x 1 x 2 - # only 1 here because for each patch we only have 1 query point - patch_query_points = track_frac[:, 0] + pradius - patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) - - # Feed the PATCH query points and tracks into fine tracker - fine_pred_track_lists, _, _, query_point_feat = fine_tracker( - query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True - ) - - # relative the patch top left - fine_pred_track = fine_pred_track_lists[-1].clone() - - # From (relative to the patch top left) to (relative to the image top left) - for idx in range(len(fine_pred_track_lists)): - fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) - fine_level = fine_level.squeeze(-2) - fine_level = fine_level + topleft_BSN - fine_pred_track_lists[idx] = fine_level - - # relative to the image top left - refined_tracks = fine_pred_track_lists[-1].clone() - refined_tracks[:, 0] = query_points - - score = None - - if compute_score: - score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) - - return refined_tracks, score - - -################################## NOTE: NOT USED ################################## - - -def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out): - """ - Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, - given the query point features and reference frame feature maps - """ - - from kornia.utils.grid import create_meshgrid - from kornia.geometry.subpix import dsnt - - # query_point_feat initial shape: B x N x C_out, - # query_point_feat indicates the feat at the coorponsing query points - # Therefore we don't have S dimension here - query_point_feat = query_point_feat.reshape(B, N, C_out) - # reshape and expand to B x (S-1) x N x C_out - query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) - # and reshape to (B*(S-1)*N) x C_out - query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) - - # Radius and size for computing the score - ssize = sradius * 2 + 1 - - # Reshape, you know it, so many reshaping operations - patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) - - # Again, we unfold the patches to smaller patches - # so that we can then focus on smaller patches - # patch_feat_unfold shape: - # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x 
ssize x ssize - # well a bit scary, but actually not - patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) - - # Do the same stuffs above, i.e., the same as extracting patches - fine_prediction_floor = fine_pred_track.floor().int() - fine_level_floor_topleft = fine_prediction_floor - sradius - - # Clamp to ensure the smaller patch is valid - fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) - fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) - - # Prepare the batch indices and xy locations - - batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN - batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N - y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices - x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices - - reference_frame_feat = patch_feat_unfold.reshape( - B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize - ) - - # Note again, according to pytorch convention - # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] - reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices] - reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) - # pick the frames other than the first one, so we have S-1 frames here - reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize) - - # Compute similarity - sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) - softmax_temp = 1.0 / C_out**0.5 - heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) - # 2D heatmaps - heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize - - coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] - grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape( - 1, -1, 2 - ) - - var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2 - std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability - - score = std.reshape(B, S - 1, N) - # set score as 1 for the query frame - score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) - - return score - - -def extract_glimpse( - tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None -): - B, C, W, H = tensor.shape - - h, w = size - xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 - ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 - - vy, vx = torch.meshgrid(ys, xs) - grid = torch.stack([vx, vy], dim=-1) # h, w, 2 - grid = grid[None] - - B, N, _ = offsets.shape - - offsets = offsets.reshape((B * N), 1, 1, 2) - offsets_grid = offsets + grid - - # normalised grid to [-1, 1] - offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2]) - - # BxCxHxW -> Bx1xCxHxW - tensor = tensor[:, None] - - # Bx1xCxHxW -> BxNxCxHxW - tensor = tensor.expand(-1, N, -1, -1, -1) - - # BxNxCxHxW -> (B*N)xCxHxW - tensor = tensor.reshape((B * N), C, W, H) - - sampled = torch.nn.functional.grid_sample( - tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode - ) - - # NOTE: I am not sure it should be h, w or w, h here - # but okay for sqaures - sampled = sampled.reshape(B, N, C, h, w) - - return 
sampled diff --git a/capvector-pi05/src/vggt/dependency/track_modules/utils.py b/capvector-pi05/src/vggt/dependency/track_modules/utils.py deleted file mode 100644 index c1b002c055ec44d2bf65f99041c47a53d6b0c9b1..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/track_modules/utils.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Modified from https://github.com/facebookresearch/PoseDiffusion -# and https://github.com/facebookresearch/co-tracker/tree/main - - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from typing import Optional, Tuple, Union -from einops import rearrange, repeat - - -def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: - """ - This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. - It is a wrapper of get_2d_sincos_pos_embed_from_grid. - Args: - - embed_dim: The embedding dimension. - - grid_size: The grid size. - Returns: - - pos_embed: The generated 2D positional embedding. - """ - if isinstance(grid_size, tuple): - grid_size_h, grid_size_w = grid_size - else: - grid_size_h = grid_size_w = grid_size - grid_h = torch.arange(grid_size_h, dtype=torch.float) - grid_w = torch.arange(grid_size_w, dtype=torch.float) - grid = torch.meshgrid(grid_w, grid_h, indexing="xy") - grid = torch.stack(grid, dim=0) - grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if return_grid: - return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) - return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) - - -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: - """ - This function generates a 2D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - grid: The grid to generate the embedding from. - - Returns: - - emb: The generated 2D positional embedding. - """ - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: - """ - This function generates a 1D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - pos: The position to generate the embedding from. - - Returns: - - emb: The generated 1D positional embedding. 
- """ - assert embed_dim % 2 == 0 - omega = torch.arange(embed_dim // 2, dtype=torch.double) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = torch.sin(out) # (M, D/2) - emb_cos = torch.cos(out) # (M, D/2) - - emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) - return emb[None].float() - - -def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: - """ - This function generates a 2D positional embedding from given coordinates using sine and cosine functions. - - Args: - - xy: The coordinates to generate the embedding from. - - C: The size of the embedding. - - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. - - Returns: - - pe: The generated 2D positional embedding. - """ - B, N, D = xy.shape - assert D == 2 - - x = xy[:, :, 0:1] - y = xy[:, :, 1:2] - div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) - - pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) - pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) - - pe_x[:, :, 0::2] = torch.sin(x * div_term) - pe_x[:, :, 1::2] = torch.cos(x * div_term) - - pe_y[:, :, 0::2] = torch.sin(y * div_term) - pe_y[:, :, 1::2] = torch.cos(y * div_term) - - pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) - if cat_coords: - pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) - return pe - - -def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): - r"""Sample a tensor using bilinear interpolation - - `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at - coordinates :attr:`coords` using bilinear interpolation. It is the same - as `torch.nn.functional.grid_sample()` but with a different coordinate - convention. - - The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where - :math:`B` is the batch size, :math:`C` is the number of channels, - :math:`H` is the height of the image, and :math:`W` is the width of the - image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is - interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. - - Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, - in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note - that in this case the order of the components is slightly different - from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. - - If `align_corners` is `True`, the coordinate :math:`x` is assumed to be - in the range :math:`[0,W-1]`, with 0 corresponding to the center of the - left-most image pixel :math:`W-1` to the center of the right-most - pixel. - - If `align_corners` is `False`, the coordinate :math:`x` is assumed to - be in the range :math:`[0,W]`, with 0 corresponding to the left edge of - the left-most pixel :math:`W` to the right edge of the right-most - pixel. - - Similar conventions apply to the :math:`y` for the range - :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range - :math:`[0,T-1]` and :math:`[0,T]`. - - Args: - input (Tensor): batch of input images. - coords (Tensor): batch of coordinates. - align_corners (bool, optional): Coordinate convention. Defaults to `True`. - padding_mode (str, optional): Padding mode. Defaults to `"border"`. - - Returns: - Tensor: sampled points. 
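The positional-embedding helpers above produce standard sine/cosine features, either over a dense grid (`get_2d_sincos_pos_embed`) or per point coordinate (`get_2d_embedding`). A hypothetical shape check, assuming the helpers are importable:

```python
# Hypothetical shape check for the sin/cos positional-embedding helpers above.
import torch

embed_dim, grid_size = 128, 16
pos_grid = get_2d_sincos_pos_embed(embed_dim, grid_size)
assert pos_grid.shape == (1, embed_dim, grid_size, grid_size)   # dense grid embedding, channels-first

B, N, C = 2, 64, 64
xy = torch.rand(B, N, 2) * 100.0                                # raw pixel coordinates
pe = get_2d_embedding(xy, C, cat_coords=True)
assert pe.shape == (B, N, 2 * C + 2)                            # sin/cos per axis, plus the raw xy
```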
- """ - - sizes = input.shape[2:] - - assert len(sizes) in [2, 3] - - if len(sizes) == 3: - # t x y -> x y t to match dimensions T H W in grid_sample - coords = coords[..., [1, 2, 0]] - - if align_corners: - coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device) - else: - coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) - - coords -= 1 - - return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) - - -def sample_features4d(input, coords): - r"""Sample spatial features - - `sample_features4d(input, coords)` samples the spatial features - :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. - - The field is sampled at coordinates :attr:`coords` using bilinear - interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, - 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the - same convention as :func:`bilinear_sampler` with `align_corners=True`. - - The output tensor has one feature per point, and has shape :math:`(B, - R, C)`. - - Args: - input (Tensor): spatial features. - coords (Tensor): points. - - Returns: - Tensor: sampled features. - """ - - B, _, _, _ = input.shape - - # B R 2 -> B R 1 2 - coords = coords.unsqueeze(2) - - # B C R 1 - feats = bilinear_sampler(input, coords) - - return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/capvector-pi05/src/vggt/dependency/track_predict.py b/capvector-pi05/src/vggt/dependency/track_predict.py deleted file mode 100644 index 22465c0b53d3e310e0025c06aed8dabccf7339d3..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/track_predict.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import numpy as np -from .vggsfm_utils import * - - -def predict_tracks( - images, - conf=None, - points_3d=None, - masks=None, - max_query_pts=2048, - query_frame_num=5, - keypoint_extractor="aliked+sp", - max_points_num=163840, - fine_tracking=True, - complete_non_vis=True, -): - """ - Predict tracks for the given images and masks. - - TODO: support non-square images - TODO: support masks - - - This function predicts the tracks for the given images and masks using the specified query method - and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. - - Args: - images: Tensor of shape [S, 3, H, W] containing the input images. - conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. - points_3d: Tensor containing 3D points. Default is None. - masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None. - max_query_pts: Maximum number of query points. Default is 2048. - query_frame_num: Number of query frames to use. Default is 5. - keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". - max_points_num: Maximum number of points to process at once. Default is 163840. - fine_tracking: Whether to use fine tracking. Default is True. - complete_non_vis: Whether to augment non-visible frames. Default is True. - - Returns: - pred_tracks: Numpy array containing the predicted tracks. - pred_vis_scores: Numpy array containing the visibility scores for the tracks. 
- pred_confs: Numpy array containing the confidence scores for the tracks. - pred_points_3d: Numpy array containing the 3D points for the tracks. - pred_colors: Numpy array containing the point colors for the tracks. (0, 255) - """ - - device = images.device - dtype = images.dtype - tracker = build_vggsfm_tracker().to(device, dtype) - - # Find query frames - query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device) - - # Add the first image to the front if not already present - if 0 in query_frame_indexes: - query_frame_indexes.remove(0) - query_frame_indexes = [0, *query_frame_indexes] - - # TODO: add the functionality to handle the masks - keypoint_extractors = initialize_feature_extractors( - max_query_pts, extractor_method=keypoint_extractor, device=device - ) - - pred_tracks = [] - pred_vis_scores = [] - pred_confs = [] - pred_points_3d = [] - pred_colors = [] - - fmaps_for_tracker = tracker.process_images_to_fmaps(images) - - if fine_tracking: - print("For faster inference, consider disabling fine_tracking") - - for query_index in query_frame_indexes: - print(f"Predicting tracks for query frame {query_index}") - pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( - query_index, - images, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num, - fine_tracking, - device, - ) - - pred_tracks.append(pred_track) - pred_vis_scores.append(pred_vis) - pred_confs.append(pred_conf) - pred_points_3d.append(pred_point_3d) - pred_colors.append(pred_color) - - if complete_non_vis: - pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames( - pred_tracks, - pred_vis_scores, - pred_confs, - pred_points_3d, - pred_colors, - images, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num, - fine_tracking, - min_vis=500, - non_vis_thresh=0.1, - device=device, - ) - - pred_tracks = np.concatenate(pred_tracks, axis=1) - pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) - pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None - pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None - pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None - - # from vggt.utils.visual_track import visualize_tracks_on_images - # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") - - return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors - - -def _forward_on_query( - query_index, - images, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num, - fine_tracking, - device, -): - """ - Process a single query frame for track prediction. 
- - Args: - query_index: Index of the query frame - images: Tensor of shape [S, 3, H, W] containing the input images - conf: Confidence tensor - points_3d: 3D points tensor - fmaps_for_tracker: Feature maps for the tracker - keypoint_extractors: Initialized feature extractors - tracker: VGG-SFM tracker - max_points_num: Maximum number of points to process at once - fine_tracking: Whether to use fine tracking - device: Device to use for computation - - Returns: - pred_track: Predicted tracks - pred_vis: Visibility scores for the tracks - pred_conf: Confidence scores for the tracks - pred_point_3d: 3D points for the tracks - pred_color: Point colors for the tracks (0, 255) - """ - frame_num, _, height, width = images.shape - - query_image = images[query_index] - query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False) - query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] - - # Extract the color at the keypoint locations - query_points_long = query_points.squeeze(0).round().long() - pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]] - pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) - - # Query the confidence and points_3d at the keypoint locations - if (conf is not None) and (points_3d is not None): - assert height == width - assert conf.shape[-2] == conf.shape[-1] - assert conf.shape[:3] == points_3d.shape[:3] - scale = conf.shape[-1] / width - - query_points_scaled = (query_points.squeeze(0) * scale).round().long() - query_points_scaled = query_points_scaled.cpu().numpy() - - pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] - pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] - - # heuristic to remove low confidence points - # should I export this as an input parameter? 
- valid_mask = pred_conf > 1.2 - if valid_mask.sum() > 512: - query_points = query_points[:, valid_mask] # Make sure shape is compatible - pred_conf = pred_conf[valid_mask] - pred_point_3d = pred_point_3d[valid_mask] - pred_color = pred_color[valid_mask] - else: - pred_conf = None - pred_point_3d = None - - reorder_index = calculate_index_mappings(query_index, frame_num, device=device) - - images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0) - images_feed = images_feed[None] # add batch dimension - fmaps_feed = fmaps_feed[None] # add batch dimension - - all_points_num = images_feed.shape[1] * query_points.shape[1] - - # Don't need to be scared, this is just chunking to make GPU happy - if all_points_num > max_points_num: - num_splits = (all_points_num + max_points_num - 1) // max_points_num - query_points = torch.chunk(query_points, num_splits, dim=1) - else: - query_points = [query_points] - - pred_track, pred_vis, _ = predict_tracks_in_chunks( - tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking - ) - - pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1) - - pred_track = pred_track.squeeze(0).float().cpu().numpy() - pred_vis = pred_vis.squeeze(0).float().cpu().numpy() - - return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color - - -def _augment_non_visible_frames( - pred_tracks: list, # ← running list of np.ndarrays - pred_vis_scores: list, # ← running list of np.ndarrays - pred_confs: list, # ← running list of np.ndarrays for confidence scores - pred_points_3d: list, # ← running list of np.ndarrays for 3D points - pred_colors: list, # ← running list of np.ndarrays for colors - images: torch.Tensor, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num: int, - fine_tracking: bool, - *, - min_vis: int = 500, - non_vis_thresh: float = 0.1, - device: torch.device = None, -): - """ - Augment tracking for frames with insufficient visibility. - - Args: - pred_tracks: List of numpy arrays containing predicted tracks. - pred_vis_scores: List of numpy arrays containing visibility scores. - pred_confs: List of numpy arrays containing confidence scores. - pred_points_3d: List of numpy arrays containing 3D points. - pred_colors: List of numpy arrays containing point colors. - images: Tensor of shape [S, 3, H, W] containing the input images. - conf: Tensor of shape [S, 1, H, W] containing confidence scores - points_3d: Tensor containing 3D points - fmaps_for_tracker: Feature maps for the tracker - keypoint_extractors: Initialized feature extractors - tracker: VGG-SFM tracker - max_points_num: Maximum number of points to process at once - fine_tracking: Whether to use fine tracking - min_vis: Minimum visibility threshold - non_vis_thresh: Non-visibility threshold - device: Device to use for computation - - Returns: - Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
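The augmentation loop described above re-queries frames whose existing tracks are mostly invisible; the per-frame check reduces to a thresholded count over the concatenated visibility scores. A standalone sketch with placeholder arrays:

```python
# Standalone sketch of the per-frame visibility check used to pick frames that need more tracks.
import numpy as np

pred_vis_scores = [np.random.rand(12, 300), np.random.rand(12, 200)]   # per-query-round results, each (S, Ni)
non_vis_thresh, min_vis = 0.1, 500

vis_array = np.concatenate(pred_vis_scores, axis=1)                    # (S, N_total)
sufficient = (vis_array > non_vis_thresh).sum(axis=-1)                 # well-tracked points per frame
non_vis_frames = np.where(sufficient < min_vis)[0].tolist()            # frames that get another query round
```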
- """ - last_query = -1 - final_trial = False - cur_extractors = keypoint_extractors # may be replaced on the final trial - - while True: - # Visibility per frame - vis_array = np.concatenate(pred_vis_scores, axis=1) - - # Count frames with sufficient visibility using numpy - sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) - non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() - - if len(non_vis_frames) == 0: - break - - print("Processing non visible frames:", non_vis_frames) - - # Decide the frames & extractor for this round - if non_vis_frames[0] == last_query: - # Same frame failed twice - final "all-in" attempt - final_trial = True - cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device) - query_frame_list = non_vis_frames # blast them all at once - else: - query_frame_list = [non_vis_frames[0]] # Process one at a time - - last_query = non_vis_frames[0] - - # Run the tracker for every selected frame - for query_index in query_frame_list: - new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( - query_index, - images, - conf, - points_3d, - fmaps_for_tracker, - cur_extractors, - tracker, - max_points_num, - fine_tracking, - device, - ) - pred_tracks.append(new_track) - pred_vis_scores.append(new_vis) - pred_confs.append(new_conf) - pred_points_3d.append(new_point_3d) - pred_colors.append(new_color) - - if final_trial: - break # Stop after final attempt - - return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors diff --git a/capvector-pi05/src/vggt/dependency/vggsfm_tracker.py b/capvector-pi05/src/vggt/dependency/vggsfm_tracker.py deleted file mode 100644 index e3940907f2bde73886af29af5e4ef8250b5b0d1b..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/vggsfm_tracker.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from torch import nn, einsum -from einops import rearrange, repeat -from einops.layers.torch import Rearrange, Reduce - -from hydra.utils import instantiate -from omegaconf import OmegaConf - -from .track_modules.track_refine import refine_track -from .track_modules.blocks import BasicEncoder, ShallowEncoder -from .track_modules.base_track_predictor import BaseTrackerPredictor - - -class TrackerPredictor(nn.Module): - def __init__(self, **extra_args): - super(TrackerPredictor, self).__init__() - """ - Initializes the tracker predictor. 
- - Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, - check track_modules/base_track_predictor.py - - Both coarse_fnet and fine_fnet are constructed as a 2D CNN network - check track_modules/blocks.py for BasicEncoder and ShallowEncoder - """ - # Define coarse predictor configuration - coarse_stride = 4 - self.coarse_down_ratio = 2 - - # Create networks directly instead of using instantiate - self.coarse_fnet = BasicEncoder(stride=coarse_stride) - self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) - - # Create fine predictor with stride = 1 - self.fine_fnet = ShallowEncoder(stride=1) - self.fine_predictor = BaseTrackerPredictor( - stride=1, - depth=4, - corr_levels=3, - corr_radius=3, - latent_dim=32, - hidden_size=256, - fine=True, - use_spaceatt=False, - ) - - def forward( - self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960 - ): - """ - Args: - images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. - query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. - fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. - coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. - inference (bool, optional): Whether to perform inference. Defaults to True. - fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. - - Returns: - tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. - """ - - if fmaps is None: - batch_num, frame_num, image_dim, height, width = images.shape - reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) - fmaps = self.process_images_to_fmaps(reshaped_image) - fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]) - - if inference: - torch.cuda.empty_cache() - - # Coarse prediction - coarse_pred_track_lists, pred_vis = self.coarse_predictor( - query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio - ) - coarse_pred_track = coarse_pred_track_lists[-1] - - if inference: - torch.cuda.empty_cache() - - if fine_tracking: - # Refine the coarse prediction - fine_pred_track, pred_score = refine_track( - images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk - ) - - if inference: - torch.cuda.empty_cache() - else: - fine_pred_track = coarse_pred_track - pred_score = torch.ones_like(pred_vis) - - return fine_pred_track, coarse_pred_track, pred_vis, pred_score - - def process_images_to_fmaps(self, images): - """ - This function processes images for inference. - - Args: - images (torch.Tensor): The images to be processed with shape S x 3 x H x W. - - Returns: - torch.Tensor: The processed feature maps. - """ - if self.coarse_down_ratio > 1: - # whether or not scale down the input images to save memory - fmaps = self.coarse_fnet( - F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True) - ) - else: - fmaps = self.coarse_fnet(images) - - return fmaps diff --git a/capvector-pi05/src/vggt/dependency/vggsfm_utils.py b/capvector-pi05/src/vggt/dependency/vggsfm_utils.py deleted file mode 100644 index d1b75497199d18d62f3fb5db1f203fe5edccedf2..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/dependency/vggsfm_utils.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import warnings -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import pycolmap -import torch -import torch.nn.functional as F -from lightglue import ALIKED, SIFT, SuperPoint - -from .vggsfm_tracker import TrackerPredictor - -# Suppress verbose logging from dependencies -logging.getLogger("dinov2").setLevel(logging.WARNING) -warnings.filterwarnings("ignore", message="xFormers is available") -warnings.filterwarnings("ignore", message="dinov2") - -# Constants -_RESNET_MEAN = [0.485, 0.456, 0.406] -_RESNET_STD = [0.229, 0.224, 0.225] - - -def build_vggsfm_tracker(model_path=None): - """ - Build and initialize the VGGSfM tracker. - - Args: - model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. - - Returns: - Initialized tracker model in eval mode. - """ - tracker = TrackerPredictor() - - if model_path is None: - default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" - tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) - else: - tracker.load_state_dict(torch.load(model_path)) - - tracker.eval() - return tracker - - -def generate_rank_by_dino( - images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False -): - """ - Generate a ranking of frames using DINO ViT features. - - Args: - images: Tensor of shape (S, 3, H, W) with values in range [0, 1] - query_frame_num: Number of frames to select - image_size: Size to resize images to before processing - model_name: Name of the DINO model to use - device: Device to run the model on - spatial_similarity: Whether to use spatial token similarity or CLS token similarity - - Returns: - List of frame indices ranked by their representativeness - """ - # Resize images to the target size - images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False) - - # Load DINO model - dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) - dino_v2_model.eval() - dino_v2_model = dino_v2_model.to(device) - - # Normalize images using ResNet normalization - resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) - resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) - images_resnet_norm = (images - resnet_mean) / resnet_std - - with torch.no_grad(): - frame_feat = dino_v2_model(images_resnet_norm, is_training=True) - - # Process features based on similarity type - if spatial_similarity: - frame_feat = frame_feat["x_norm_patchtokens"] - frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) - - # Compute the similarity matrix - frame_feat_norm = frame_feat_norm.permute(1, 0, 2) - similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) - similarity_matrix = similarity_matrix.mean(dim=0) - else: - frame_feat = frame_feat["x_norm_clstoken"] - frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) - similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) - - distance_matrix = 100 - similarity_matrix.clone() - - # Ignore self-pairing - similarity_matrix.fill_diagonal_(-100) - similarity_sum = similarity_matrix.sum(dim=1) - - # Find the most common frame - most_common_frame_index = torch.argmax(similarity_sum).item() - - # Conduct FPS sampling starting from the most common frame - fps_idx = 
farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) - - # Clean up all tensors and models to free memory - del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix - del dino_v2_model - torch.cuda.empty_cache() - - return fps_idx - - -def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): - """ - Farthest point sampling algorithm to select diverse frames. - - Args: - distance_matrix: Matrix of distances between frames - num_samples: Number of frames to select - most_common_frame_index: Index of the first frame to select - - Returns: - List of selected frame indices - """ - distance_matrix = distance_matrix.clamp(min=0) - N = distance_matrix.size(0) - - # Initialize with the most common frame - selected_indices = [most_common_frame_index] - check_distances = distance_matrix[selected_indices] - - while len(selected_indices) < num_samples: - # Find the farthest point from the current set of selected points - farthest_point = torch.argmax(check_distances) - selected_indices.append(farthest_point.item()) - - check_distances = distance_matrix[farthest_point] - # Mark already selected points to avoid selecting them again - check_distances[selected_indices] = 0 - - # Break if all points have been selected - if len(selected_indices) == N: - break - - return selected_indices - - -def calculate_index_mappings(query_index, S, device=None): - """ - Construct an order that switches [query_index] and [0] - so that the content of query_index would be placed at [0]. - - Args: - query_index: Index to swap with 0 - S: Total number of elements - device: Device to place the tensor on - - Returns: - Tensor of indices with the swapped order - """ - new_order = torch.arange(S) - new_order[0] = query_index - new_order[query_index] = 0 - if device is not None: - new_order = new_order.to(device) - return new_order - - -def switch_tensor_order(tensors, order, dim=1): - """ - Reorder tensors along a specific dimension according to the given order. - - Args: - tensors: List of tensors to reorder - order: Tensor of indices specifying the new order - dim: Dimension along which to reorder - - Returns: - List of reordered tensors - """ - return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors] - - -def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"): - """ - Initialize feature extractors that can be reused based on a method string. 
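`generate_rank_by_dino` above turns DINO feature similarities into a distance matrix and hands it to `farthest_point_sampling`; `calculate_index_mappings` and `switch_tensor_order` then move the chosen query frame to position 0 for the tracker. A toy run, assuming the helpers are importable:

```python
# Hypothetical toy run of the frame-selection helpers above (random distances).
import torch

S = 6
dist = torch.rand(S, S)
dist = (dist + dist.T) / 2                 # symmetric toy "distance" between frames
dist.fill_diagonal_(0)

query_frames = farthest_point_sampling(dist, num_samples=3, most_common_frame_index=0)
# e.g. [0, 4, 2]: mutually dissimilar frames to use as tracking queries

order = calculate_index_mappings(query_frames[1], S)     # swap the chosen frame with index 0
images = torch.rand(S, 3, 64, 64)
(images_reordered,) = switch_tensor_order([images], order, dim=0)
```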
- - Args: - max_query_num: Maximum number of keypoints to extract - det_thres: Detection threshold for keypoint extraction - extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") - device: Device to run extraction on - - Returns: - Dictionary of initialized extractors - """ - extractors = {} - methods = extractor_method.lower().split("+") - - for method in methods: - method = method.strip() - if method == "aliked": - aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) - extractors["aliked"] = aliked_extractor.to(device).eval() - elif method == "sp": - sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres) - extractors["sp"] = sp_extractor.to(device).eval() - elif method == "sift": - sift_extractor = SIFT(max_num_keypoints=max_query_num) - extractors["sift"] = sift_extractor.to(device).eval() - else: - print(f"Warning: Unknown feature extractor '{method}', ignoring.") - - if not extractors: - print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") - aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) - extractors["aliked"] = aliked_extractor.to(device).eval() - - return extractors - - -def extract_keypoints(query_image, extractors, round_keypoints=True): - """ - Extract keypoints using pre-initialized feature extractors. - - Args: - query_image: Input image tensor (3xHxW, range [0, 1]) - extractors: Dictionary of initialized extractors - - Returns: - Tensor of keypoint coordinates (1xNx2) - """ - query_points = None - - with torch.no_grad(): - for extractor_name, extractor in extractors.items(): - query_points_data = extractor.extract(query_image, invalid_mask=None) - extractor_points = query_points_data["keypoints"] - if round_keypoints: - extractor_points = extractor_points.round() - - if query_points is not None: - query_points = torch.cat([query_points, extractor_points], dim=1) - else: - query_points = extractor_points - - return query_points - - -def predict_tracks_in_chunks( - track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960 -): - """ - Process a list of query points to avoid memory issues. - - Args: - track_predictor (object): The track predictor object used for predicting tracks. - images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. - query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. - fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. - fine_tracking (bool): Whether to perform fine tracking. - num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. - - Returns: - tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. 
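`initialize_feature_extractors` and `extract_keypoints` above wrap the LightGlue detectors (ALIKED, SuperPoint, SIFT) behind a single call that pools keypoints from every requested method. A hypothetical call, assuming the same `lightglue` package is installed:

```python
# Hypothetical keypoint extraction using the helpers above (requires lightglue's ALIKED / SuperPoint).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
extractors = initialize_feature_extractors(max_query_num=1024, extractor_method="aliked+sp", device=device)

query_image = torch.rand(3, 518, 518, device=device)        # single RGB frame in [0, 1]
query_points = extract_keypoints(query_image, extractors, round_keypoints=False)
# query_points: (1, N, 2) xy coordinates pooled from all requested extractors
```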
- """ - # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility - if not isinstance(query_points_list, (list, tuple)): - query_points = query_points_list - if num_splits is None: - num_splits = 1 - query_points_list = torch.chunk(query_points, num_splits, dim=1) - - # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) - if isinstance(query_points_list, tuple): - query_points_list = list(query_points_list) - - fine_pred_track_list = [] - pred_vis_list = [] - pred_score_list = [] - - for split_points in query_points_list: - # Feed into track predictor for each split - fine_pred_track, _, pred_vis, pred_score = track_predictor( - images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk - ) - fine_pred_track_list.append(fine_pred_track) - pred_vis_list.append(pred_vis) - pred_score_list.append(pred_score) - - # Concatenate the results from all splits - fine_pred_track = torch.cat(fine_pred_track_list, dim=2) - pred_vis = torch.cat(pred_vis_list, dim=2) - - if pred_score is not None: - pred_score = torch.cat(pred_score_list, dim=2) - else: - pred_score = None - - return fine_pred_track, pred_vis, pred_score diff --git a/capvector-pi05/src/vggt/heads/camera_head.py b/capvector-pi05/src/vggt/heads/camera_head.py deleted file mode 100644 index b88b7d6e909adf45853af03d546343b5a8bbe472..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/camera_head.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from vggt.layers import Mlp -from vggt.layers.block import Block -from vggt.heads.head_act import activate_pose - - -class CameraHead(nn.Module): - """ - CameraHead predicts camera parameters from token representations using iterative refinement. - - It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. - """ - - def __init__( - self, - dim_in: int = 2048, - trunk_depth: int = 4, - pose_encoding_type: str = "absT_quaR_FoV", - num_heads: int = 16, - mlp_ratio: int = 4, - init_values: float = 0.01, - trans_act: str = "linear", - quat_act: str = "linear", - fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. - ): - super().__init__() - - if pose_encoding_type == "absT_quaR_FoV": - self.target_dim = 9 - else: - raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") - - self.trans_act = trans_act - self.quat_act = quat_act - self.fl_act = fl_act - self.trunk_depth = trunk_depth - - # Build the trunk using a sequence of transformer blocks. - self.trunk = nn.Sequential( - *[ - Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) - for _ in range(trunk_depth) - ] - ) - - # Normalizations for camera token and trunk output. - self.token_norm = nn.LayerNorm(dim_in) - self.trunk_norm = nn.LayerNorm(dim_in) - - # Learnable empty camera pose token. - self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) - self.embed_pose = nn.Linear(self.target_dim, dim_in) - - # Module for producing modulation parameters: shift, scale, and a gate. 
- self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) - - # Adaptive layer normalization without affine parameters. - self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) - self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) - - def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: - """ - Forward pass to predict camera parameters. - - Args: - aggregated_tokens_list (list): List of token tensors from the network; - the last tensor is used for prediction. - num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. - - Returns: - list: A list of predicted camera encodings (post-activation) from each iteration. - """ - # Use tokens from the last block for camera prediction. - tokens = aggregated_tokens_list[-1] - - # Extract the camera tokens - pose_tokens = tokens[:, :, 0] - pose_tokens = self.token_norm(pose_tokens) - - pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) - return pred_pose_enc_list - - def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: - """ - Iteratively refine camera pose predictions. - - Args: - pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. - num_iterations (int): Number of refinement iterations. - - Returns: - list: List of activated camera encodings from each iteration. - """ - B, S, C = pose_tokens.shape - pred_pose_enc = None - pred_pose_enc_list = [] - - for _ in range(num_iterations): - # Use a learned empty pose for the first iteration. - if pred_pose_enc is None: - module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) - else: - # Detach the previous prediction to avoid backprop through time. - pred_pose_enc = pred_pose_enc.detach() - module_input = self.embed_pose(pred_pose_enc) - - # Generate modulation parameters and split them into shift, scale, and gate components. - shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) - - # Adaptive layer normalization and modulation. - pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) - pose_tokens_modulated = pose_tokens_modulated + pose_tokens - - pose_tokens_modulated = self.trunk(pose_tokens_modulated) - # Compute the delta update for the pose encoding. - pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) - - if pred_pose_enc is None: - pred_pose_enc = pred_pose_enc_delta - else: - pred_pose_enc = pred_pose_enc + pred_pose_enc_delta - - # Apply final activation functions for translation, quaternion, and field-of-view. - activated_pose = activate_pose( - pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act - ) - pred_pose_enc_list.append(activated_pose) - - return pred_pose_enc_list - - -def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - """ - Modulate the input tensor using scaling and shifting parameters. - """ - # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 - return x * (1 + scale) + shift diff --git a/capvector-pi05/src/vggt/heads/dpt_head.py b/capvector-pi05/src/vggt/heads/dpt_head.py deleted file mode 100644 index 6f88a404f50735ece07b44714c986de0f4efcfe3..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/dpt_head.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 - - -import os -from typing import List, Dict, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from .head_act import activate_head -from .utils import create_uv_grid, position_grid_to_embed - - -class DPTHead(nn.Module): - """ - DPT Head for dense prediction tasks. - - This implementation follows the architecture described in "Vision Transformers for Dense Prediction" - (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer - backbone and produces dense predictions by fusing multi-scale features. - - Args: - dim_in (int): Input dimension (channels). - patch_size (int, optional): Patch size. Default is 14. - output_dim (int, optional): Number of output channels. Default is 4. - activation (str, optional): Activation type. Default is "inv_log". - conf_activation (str, optional): Confidence activation type. Default is "expp1". - features (int, optional): Feature channels for intermediate representations. Default is 256. - out_channels (List[int], optional): Output channels for each intermediate layer. - intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. - pos_embed (bool, optional): Whether to use positional embedding. Default is True. - feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. - down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. - """ - - def __init__( - self, - dim_in: int, - patch_size: int = 14, - output_dim: int = 4, - activation: str = "inv_log", - conf_activation: str = "expp1", - features: int = 256, - out_channels: List[int] = [256, 512, 1024, 1024], - intermediate_layer_idx: List[int] = [4, 11, 17, 23], - pos_embed: bool = True, - feature_only: bool = False, - down_ratio: int = 1, - ) -> None: - super(DPTHead, self).__init__() - self.patch_size = patch_size - self.activation = activation - self.conf_activation = conf_activation - self.pos_embed = pos_embed - self.feature_only = feature_only - self.down_ratio = down_ratio - self.intermediate_layer_idx = intermediate_layer_idx - - self.norm = nn.LayerNorm(dim_in) - - # Projection layers for each output channel from tokens. - self.projects = nn.ModuleList( - [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] - ) - - # Resize layers for upsampling feature maps. - self.resize_layers = nn.ModuleList( - [ - nn.ConvTranspose2d( - in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 - ), - nn.ConvTranspose2d( - in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 - ), - nn.Identity(), - nn.Conv2d( - in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 - ), - ] - ) - - self.scratch = _make_scratch(out_channels, features, expand=False) - - # Attach additional modules to scratch. 
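        # Editorial note (not part of the original file): _make_scratch() above creates
        # scratch.layer1_rn .. layer4_rn, 3x3 convolutions that project the four selected
        # transformer layers to a common `features` width; scratch_forward() below then
        # fuses them coarse to fine through refinenet4 -> refinenet3 -> refinenet2 -> refinenet1.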
- self.scratch.stem_transpose = None - self.scratch.refinenet1 = _make_fusion_block(features) - self.scratch.refinenet2 = _make_fusion_block(features) - self.scratch.refinenet3 = _make_fusion_block(features) - self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) - - head_features_1 = features - head_features_2 = 32 - - if feature_only: - self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) - else: - self.scratch.output_conv1 = nn.Conv2d( - head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 - ) - conv2_in_channels = head_features_1 // 2 - - self.scratch.output_conv2 = nn.Sequential( - nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), - ) - - def forward( - self, - aggregated_tokens_list: List[torch.Tensor], - images: torch.Tensor, - patch_start_idx: int, - frames_chunk_size: int = 8, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Forward pass through the DPT head, supports processing by chunking frames. - Args: - aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. - images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. - patch_start_idx (int): Starting index for patch tokens in the token sequence. - Used to separate patch tokens from other tokens (e.g., camera or register tokens). - frames_chunk_size (int, optional): Number of frames to process in each chunk. - If None or larger than S, all frames are processed at once. Default: 8. - - Returns: - Tensor or Tuple[Tensor, Tensor]: - - If feature_only=True: Feature maps with shape [B, S, C, H, W] - - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] - """ - B, S, _, H, W = images.shape - - # If frames_chunk_size is not specified or greater than S, process all frames at once - if frames_chunk_size is None or frames_chunk_size >= S: - return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) - - # Otherwise, process frames in chunks to manage memory usage - assert frames_chunk_size > 0 - - # Process frames in batches - all_preds = [] - all_conf = [] - - for frames_start_idx in range(0, S, frames_chunk_size): - frames_end_idx = min(frames_start_idx + frames_chunk_size, S) - - # Process batch of frames - if self.feature_only: - chunk_output = self._forward_impl( - aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx - ) - all_preds.append(chunk_output) - else: - chunk_preds, chunk_conf = self._forward_impl( - aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx - ) - all_preds.append(chunk_preds) - all_conf.append(chunk_conf) - - # Concatenate results along the sequence dimension - if self.feature_only: - return torch.cat(all_preds, dim=1) - else: - return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) - - def _forward_impl( - self, - aggregated_tokens_list: List[torch.Tensor], - images: torch.Tensor, - patch_start_idx: int, - frames_start_idx: int = None, - frames_end_idx: int = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Implementation of the forward pass through the DPT head. - - This method processes a specific chunk of frames from the sequence. - - Args: - aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. 
- images (Tensor): Input images with shape [B, S, 3, H, W]. - patch_start_idx (int): Starting index for patch tokens. - frames_start_idx (int, optional): Starting index for frames to process. - frames_end_idx (int, optional): Ending index for frames to process. - - Returns: - Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). - """ - if frames_start_idx is not None and frames_end_idx is not None: - images = images[:, frames_start_idx:frames_end_idx].contiguous() - - B, S, _, H, W = images.shape - - patch_h, patch_w = H // self.patch_size, W // self.patch_size - - out = [] - dpt_idx = 0 - - for layer_idx in self.intermediate_layer_idx: - x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] - - # Select frames if processing a chunk - if frames_start_idx is not None and frames_end_idx is not None: - x = x[:, frames_start_idx:frames_end_idx] - - x = x.reshape(B * S, -1, x.shape[-1]) - - x = self.norm(x) - - x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) - - x = self.projects[dpt_idx](x) - if self.pos_embed: - x = self._apply_pos_embed(x, W, H) - x = self.resize_layers[dpt_idx](x) - - out.append(x) - dpt_idx += 1 - - # Fuse features from multiple layers. - out = self.scratch_forward(out) - # Interpolate fused output to match target image resolution. - out = custom_interpolate( - out, - (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), - mode="bilinear", - align_corners=True, - ) - - if self.pos_embed: - out = self._apply_pos_embed(out, W, H) - - if self.feature_only: - return out.view(B, S, *out.shape[1:]) - - out = self.scratch.output_conv2(out) - preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) - - preds = preds.view(B, S, *preds.shape[1:]) - conf = conf.view(B, S, *conf.shape[1:]) - return preds, conf - - def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: - """ - Apply positional embedding to tensor x. - """ - patch_w = x.shape[-1] - patch_h = x.shape[-2] - pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) - pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) - pos_embed = pos_embed * ratio - pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) - return x + pos_embed - - def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: - """ - Forward pass through the fusion blocks. - - Args: - features (List[Tensor]): List of feature maps from different layers. - - Returns: - Tensor: Fused feature map. 
- """ - layer_1, layer_2, layer_3, layer_4 = features - - layer_1_rn = self.scratch.layer1_rn(layer_1) - layer_2_rn = self.scratch.layer2_rn(layer_2) - layer_3_rn = self.scratch.layer3_rn(layer_3) - layer_4_rn = self.scratch.layer4_rn(layer_4) - - out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) - del layer_4_rn, layer_4 - - out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) - del layer_3_rn, layer_3 - - out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) - del layer_2_rn, layer_2 - - out = self.scratch.refinenet1(out, layer_1_rn) - del layer_1_rn, layer_1 - - out = self.scratch.output_conv1(out) - return out - - -################################################################################ -# Modules -################################################################################ - - -def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: - return FeatureFusionBlock( - features, - nn.ReLU(inplace=True), - deconv=False, - bn=False, - expand=False, - align_corners=True, - size=size, - has_residual=has_residual, - groups=groups, - ) - - -def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: - scratch = nn.Module() - out_shape1 = out_shape - out_shape2 = out_shape - out_shape3 = out_shape - if len(in_shape) >= 4: - out_shape4 = out_shape - - if expand: - out_shape1 = out_shape - out_shape2 = out_shape * 2 - out_shape3 = out_shape * 4 - if len(in_shape) >= 4: - out_shape4 = out_shape * 8 - - scratch.layer1_rn = nn.Conv2d( - in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer2_rn = nn.Conv2d( - in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer3_rn = nn.Conv2d( - in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - if len(in_shape) >= 4: - scratch.layer4_rn = nn.Conv2d( - in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - return scratch - - -class ResidualConvUnit(nn.Module): - """Residual convolution module.""" - - def __init__(self, features, activation, bn, groups=1): - """Init. - - Args: - features (int): number of features - """ - super().__init__() - - self.bn = bn - self.groups = groups - self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - - self.norm1 = None - self.norm2 = None - - self.activation = activation - self.skip_add = nn.quantized.FloatFunctional() - - def forward(self, x): - """Forward pass. - - Args: - x (tensor): input - - Returns: - tensor: output - """ - - out = self.activation(x) - out = self.conv1(out) - if self.norm1 is not None: - out = self.norm1(out) - - out = self.activation(out) - out = self.conv2(out) - if self.norm2 is not None: - out = self.norm2(out) - - return self.skip_add.add(out, x) - - -class FeatureFusionBlock(nn.Module): - """Feature fusion block.""" - - def __init__( - self, - features, - activation, - deconv=False, - bn=False, - expand=False, - align_corners=True, - size=None, - has_residual=True, - groups=1, - ): - """Init. 
- - Args: - features (int): number of features - """ - super(FeatureFusionBlock, self).__init__() - - self.deconv = deconv - self.align_corners = align_corners - self.groups = groups - self.expand = expand - out_features = features - if self.expand == True: - out_features = features // 2 - - self.out_conv = nn.Conv2d( - features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups - ) - - if has_residual: - self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) - - self.has_residual = has_residual - self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) - - self.skip_add = nn.quantized.FloatFunctional() - self.size = size - - def forward(self, *xs, size=None): - """Forward pass. - - Returns: - tensor: output - """ - output = xs[0] - - if self.has_residual: - res = self.resConfUnit1(xs[1]) - output = self.skip_add.add(output, res) - - output = self.resConfUnit2(output) - - if (size is None) and (self.size is None): - modifier = {"scale_factor": 2} - elif size is None: - modifier = {"size": self.size} - else: - modifier = {"size": size} - - output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) - output = self.out_conv(output) - - return output - - -def custom_interpolate( - x: torch.Tensor, - size: Tuple[int, int] = None, - scale_factor: float = None, - mode: str = "bilinear", - align_corners: bool = True, -) -> torch.Tensor: - """ - Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. - """ - if size is None: - size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) - - INT_MAX = 1610612736 - - input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] - - if input_elements > INT_MAX: - chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) - interpolated_chunks = [ - nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks - ] - x = torch.cat(interpolated_chunks, dim=0) - return x.contiguous() - else: - return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/capvector-pi05/src/vggt/heads/head_act.py b/capvector-pi05/src/vggt/heads/head_act.py deleted file mode 100644 index a37669d50cf9b52dba297c2c7a5bea00987bb67d..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/head_act.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -import torch.nn.functional as F - - -def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): - """ - Activate pose parameters with specified activation functions. 
- - Args: - pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] - trans_act: Activation type for translation component - quat_act: Activation type for quaternion component - fl_act: Activation type for focal length component - - Returns: - Activated pose parameters tensor - """ - T = pred_pose_enc[..., :3] - quat = pred_pose_enc[..., 3:7] - fl = pred_pose_enc[..., 7:] # or fov - - T = base_pose_act(T, trans_act) - quat = base_pose_act(quat, quat_act) - fl = base_pose_act(fl, fl_act) # or fov - - pred_pose_enc = torch.cat([T, quat, fl], dim=-1) - - return pred_pose_enc - - -def base_pose_act(pose_enc, act_type="linear"): - """ - Apply basic activation function to pose parameters. - - Args: - pose_enc: Tensor containing encoded pose parameters - act_type: Activation type ("linear", "inv_log", "exp", "relu") - - Returns: - Activated pose parameters - """ - if act_type == "linear": - return pose_enc - elif act_type == "inv_log": - return inverse_log_transform(pose_enc) - elif act_type == "exp": - return torch.exp(pose_enc) - elif act_type == "relu": - return F.relu(pose_enc) - else: - raise ValueError(f"Unknown act_type: {act_type}") - - -def activate_head(out, activation="norm_exp", conf_activation="expp1"): - """ - Process network output to extract 3D points and confidence values. - - Args: - out: Network output tensor (B, C, H, W) - activation: Activation type for 3D points - conf_activation: Activation type for confidence values - - Returns: - Tuple of (3D points tensor, confidence tensor) - """ - # Move channels from last dim to the 4th dimension => (B, H, W, C) - fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected - - # Split into xyz (first C-1 channels) and confidence (last channel) - xyz = fmap[:, :, :, :-1] - conf = fmap[:, :, :, -1] - - if activation == "norm_exp": - d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) - xyz_normed = xyz / d - pts3d = xyz_normed * torch.expm1(d) - elif activation == "norm": - pts3d = xyz / xyz.norm(dim=-1, keepdim=True) - elif activation == "exp": - pts3d = torch.exp(xyz) - elif activation == "relu": - pts3d = F.relu(xyz) - elif activation == "inv_log": - pts3d = inverse_log_transform(xyz) - elif activation == "xy_inv_log": - xy, z = xyz.split([2, 1], dim=-1) - z = inverse_log_transform(z) - pts3d = torch.cat([xy * z, z], dim=-1) - elif activation == "sigmoid": - pts3d = torch.sigmoid(xyz) - elif activation == "linear": - pts3d = xyz - else: - raise ValueError(f"Unknown activation: {activation}") - - if conf_activation == "expp1": - conf_out = 1 + conf.exp() - elif conf_activation == "expp0": - conf_out = conf.exp() - elif conf_activation == "sigmoid": - conf_out = torch.sigmoid(conf) - else: - raise ValueError(f"Unknown conf_activation: {conf_activation}") - - return pts3d, conf_out - - -def inverse_log_transform(y): - """ - Apply inverse log transform: sign(y) * (exp(|y|) - 1) - - Args: - y: Input tensor - - Returns: - Transformed tensor - """ - return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/capvector-pi05/src/vggt/heads/track_head.py b/capvector-pi05/src/vggt/heads/track_head.py deleted file mode 100644 index e6356cbd8273557deac9225cd09125ebca34fc65..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/track_head.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
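A minimal, self-contained sketch of what the `head_act.py` activation helpers above do for the 9-dimensional "absT_quaR_FoV" pose encoding used by `CameraHead`; the tensor shapes are toy values, and the one-line transforms are restated inline rather than imported from the deleted module:

```python
import torch


def inverse_log_transform(y):
    # sign(y) * (exp(|y|) - 1), as defined in head_act.py above
    return torch.sign(y) * torch.expm1(torch.abs(y))


# A 9-D pose encoding: 3 translation, 4 quaternion, 2 field-of-view values (toy shapes).
pose_enc = torch.randn(2, 8, 9)  # (batch, frames, 9)

# With CameraHead's defaults (trans_act="linear", quat_act="linear", fl_act="relu"),
# activate_pose() leaves translation and quaternion untouched and clamps FoV to >= 0:
trans = pose_enc[..., :3]
quat = pose_enc[..., 3:7]
fov = torch.relu(pose_enc[..., 7:])
activated = torch.cat([trans, quat, fov], dim=-1)

assert activated.shape == pose_enc.shape
assert (activated[..., 7:] >= 0).all()

# inverse_log_transform ("inv_log") is the activation DPTHead applies by default to its
# dense outputs; it is monotonic and undoes a signed log1p compression of large magnitudes.
print(inverse_log_transform(torch.tensor([-2.0, 0.0, 2.0])))
```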
- -import torch.nn as nn -from .dpt_head import DPTHead -from .track_modules.base_track_predictor import BaseTrackerPredictor - - -class TrackHead(nn.Module): - """ - Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. - The tracking is performed iteratively, refining predictions over multiple iterations. - """ - - def __init__( - self, - dim_in, - patch_size=14, - features=128, - iters=4, - predict_conf=True, - stride=2, - corr_levels=7, - corr_radius=4, - hidden_size=384, - ): - """ - Initialize the TrackHead module. - - Args: - dim_in (int): Input dimension of tokens from the backbone. - patch_size (int): Size of image patches used in the vision transformer. - features (int): Number of feature channels in the feature extractor output. - iters (int): Number of refinement iterations for tracking predictions. - predict_conf (bool): Whether to predict confidence scores for tracked points. - stride (int): Stride value for the tracker predictor. - corr_levels (int): Number of correlation pyramid levels - corr_radius (int): Radius for correlation computation, controlling the search area. - hidden_size (int): Size of hidden layers in the tracker network. - """ - super().__init__() - - self.patch_size = patch_size - - # Feature extractor based on DPT architecture - # Processes tokens into feature maps for tracking - self.feature_extractor = DPTHead( - dim_in=dim_in, - patch_size=patch_size, - features=features, - feature_only=True, # Only output features, no activation - down_ratio=2, # Reduces spatial dimensions by factor of 2 - pos_embed=False, - ) - - # Tracker module that predicts point trajectories - # Takes feature maps and predicts coordinates and visibility - self.tracker = BaseTrackerPredictor( - latent_dim=features, # Match the output_dim of feature extractor - predict_conf=predict_conf, - stride=stride, - corr_levels=corr_levels, - corr_radius=corr_radius, - hidden_size=hidden_size, - ) - - self.iters = iters - - def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): - """ - Forward pass of the TrackHead. - - Args: - aggregated_tokens_list (list): List of aggregated tokens from the backbone. - images (torch.Tensor): Input images of shape (B, S, C, H, W) where: - B = batch size, S = sequence length. - patch_start_idx (int): Starting index for patch tokens. - query_points (torch.Tensor, optional): Initial query points to track. - If None, points are initialized by the tracker. - iters (int, optional): Number of refinement iterations. If None, uses self.iters. - - Returns: - tuple: - - coord_preds (torch.Tensor): Predicted coordinates for tracked points. - - vis_scores (torch.Tensor): Visibility scores for tracked points. - - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). 
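        Example (illustrative only; the input tensors and the value of patch_start_idx
        are assumed, not taken from the original code):
            >>> coord_preds, vis_scores, conf_scores = track_head(
            ...     aggregated_tokens_list, images, patch_start_idx=5, query_points=query_points
            ... )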
- """ - B, S, _, H, W = images.shape - - # Extract features from tokens - # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 - feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) - - # Use default iterations if not specified - if iters is None: - iters = self.iters - - # Perform tracking using the extracted features - coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters) - - return coord_preds, vis_scores, conf_scores diff --git a/capvector-pi05/src/vggt/heads/track_modules/__init__.py b/capvector-pi05/src/vggt/heads/track_modules/__init__.py deleted file mode 100644 index c4196294309799347172dba54a17360698071ca8..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/track_modules/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. diff --git a/capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py b/capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py deleted file mode 100644 index 540c1d110d4b35b36fdbd2a8a81121d9f9cf2f9b..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from einops import rearrange, repeat - - -from .blocks import EfficientUpdateFormer, CorrBlock -from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed -from .modules import Mlp - - -class BaseTrackerPredictor(nn.Module): - def __init__( - self, - stride=1, - corr_levels=5, - corr_radius=4, - latent_dim=128, - hidden_size=384, - use_spaceatt=True, - depth=6, - max_scale=518, - predict_conf=True, - ): - super(BaseTrackerPredictor, self).__init__() - """ - The base template to create a track predictor - - Modified from https://github.com/facebookresearch/co-tracker/ - and https://github.com/facebookresearch/vggsfm - """ - - self.stride = stride - self.latent_dim = latent_dim - self.corr_levels = corr_levels - self.corr_radius = corr_radius - self.hidden_size = hidden_size - self.max_scale = max_scale - self.predict_conf = predict_conf - - self.flows_emb_dim = latent_dim // 2 - - self.corr_mlp = Mlp( - in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, - hidden_features=self.hidden_size, - out_features=self.latent_dim, - ) - - self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 - - self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) - - space_depth = depth if use_spaceatt else 0 - time_depth = depth - - self.updateformer = EfficientUpdateFormer( - space_depth=space_depth, - time_depth=time_depth, - input_dim=self.transformer_dim, - hidden_size=self.hidden_size, - output_dim=self.latent_dim + 2, - mlp_ratio=4.0, - add_space_attn=use_spaceatt, - ) - - self.fmap_norm = nn.LayerNorm(self.latent_dim) - self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) - - # A linear layer to update track feats at each iteration - self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) - - self.vis_predictor = 
nn.Sequential(nn.Linear(self.latent_dim, 1)) - - if predict_conf: - self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) - - def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): - """ - query_points: B x N x 2, the number of batches, tracks, and xy - fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. - note HH and WW is the size of feature maps instead of original images - """ - B, N, D = query_points.shape - B, S, C, HH, WW = fmaps.shape - - assert D == 2, "Input points must be 2D coordinates" - - # apply a layernorm to fmaps here - fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) - fmaps = fmaps.permute(0, 1, 4, 2, 3) - - # Scale the input query_points because we may downsample the images - # by down_ratio or self.stride - # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map - # its query_points should be query_points/4 - if down_ratio > 1: - query_points = query_points / float(down_ratio) - - query_points = query_points / float(self.stride) - - # Init with coords as the query points - # It means the search will start from the position of query points at the reference frames - coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) - - # Sample/extract the features of the query points in the query frame - query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) - - # init track feats by query feats - track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C - # back up the init coords - coords_backup = coords.clone() - - fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) - - coord_preds = [] - - # Iterative Refinement - for _ in range(iters): - # Detach the gradients from the last iteration - # (in my experience, not very important for performance) - coords = coords.detach() - - fcorrs = fcorr_fn.corr_sample(track_feats, coords) - - corr_dim = fcorrs.shape[3] - fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) - fcorrs_ = self.corr_mlp(fcorrs_) - - # Movement of current coords relative to query points - flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) - - flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) - - # (In my trials, it is also okay to just add the flows_emb instead of concat) - flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) - - track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) - - # Concatenate them as the input for the transformers - transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) - - # 2D positional embed - # TODO: this can be much simplified - pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) - sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) - - sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) - - x = transformer_input + sampled_pos_emb - - # Add the query ref token to the track feats - query_ref_token = torch.cat( - [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 - ) - x = x + query_ref_token.to(x.device).to(x.dtype) - - # B, N, S, C - x = rearrange(x, "(b n) s d -> b n s d", b=B) - - # Compute the delta coordinates and delta track features - delta, _ = self.updateformer(x) - - # BN, S, C - delta = rearrange(delta, " b n s d -> (b n) s d", b=B) - delta_coords_ = 
delta[:, :, :2] - delta_feats_ = delta[:, :, 2:] - - track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) - delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) - - # Update the track features - track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ - - track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC - - # B x S x N x 2 - coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) - - # Force coord0 as query - # because we assume the query points should not be changed - coords[:, 0] = coords_backup[:, 0] - - # The predicted tracks are in the original image scale - if down_ratio > 1: - coord_preds.append(coords * self.stride * down_ratio) - else: - coord_preds.append(coords * self.stride) - - # B, S, N - vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) - if apply_sigmoid: - vis_e = torch.sigmoid(vis_e) - - if self.predict_conf: - conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) - if apply_sigmoid: - conf_e = torch.sigmoid(conf_e) - else: - conf_e = None - - if return_feat: - return coord_preds, vis_e, track_feats, query_track_feat, conf_e - else: - return coord_preds, vis_e, conf_e diff --git a/capvector-pi05/src/vggt/heads/track_modules/blocks.py b/capvector-pi05/src/vggt/heads/track_modules/blocks.py deleted file mode 100644 index 394c31d120a716bee1e82911c841c0f63d9965d3..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/track_modules/blocks.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -# Modified from https://github.com/facebookresearch/co-tracker/ - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import bilinear_sampler -from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock - - -class EfficientUpdateFormer(nn.Module): - """ - Transformer model that updates track estimates. 
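    (Editorial addition: it alternates temporal attention along each track's own timeline
    with optional spatial attention, in which a set of learned virtual-track tokens
    exchanges information with the real track tokens through cross-attention; see the
    forward pass below.)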
- """ - - def __init__( - self, - space_depth=6, - time_depth=6, - input_dim=320, - hidden_size=384, - num_heads=8, - output_dim=130, - mlp_ratio=4.0, - add_space_attn=True, - num_virtual_tracks=64, - ): - super().__init__() - - self.out_channels = 2 - self.num_heads = num_heads - self.hidden_size = hidden_size - self.add_space_attn = add_space_attn - - # Add input LayerNorm before linear projection - self.input_norm = nn.LayerNorm(input_dim) - self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) - - # Add output LayerNorm before final projection - self.output_norm = nn.LayerNorm(hidden_size) - self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) - self.num_virtual_tracks = num_virtual_tracks - - if self.add_space_attn: - self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) - else: - self.virual_tracks = None - - self.time_blocks = nn.ModuleList( - [ - AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) - for _ in range(time_depth) - ] - ) - - if add_space_attn: - self.space_virtual_blocks = nn.ModuleList( - [ - AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) - for _ in range(space_depth) - ] - ) - self.space_point2virtual_blocks = nn.ModuleList( - [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] - ) - self.space_virtual2point_blocks = nn.ModuleList( - [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] - ) - assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) - self.initialize_weights() - - def initialize_weights(self): - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) - - self.apply(_basic_init) - - def forward(self, input_tensor, mask=None): - # Apply input LayerNorm - input_tensor = self.input_norm(input_tensor) - tokens = self.input_transform(input_tensor) - - init_tokens = tokens - - B, _, T, _ = tokens.shape - - if self.add_space_attn: - virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) - tokens = torch.cat([tokens, virtual_tokens], dim=1) - - _, N, _, _ = tokens.shape - - j = 0 - for i in range(len(self.time_blocks)): - time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C - - time_tokens = self.time_blocks[i](time_tokens) - - tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C - if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): - space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C - point_tokens = space_tokens[:, : N - self.num_virtual_tracks] - virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] - - virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) - virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) - point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) - - space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) - tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C - j += 1 - - if self.add_space_attn: - tokens = tokens[:, : N - self.num_virtual_tracks] - - tokens = tokens + init_tokens - - # Apply output LayerNorm before final projection - tokens = 
self.output_norm(tokens) - flow = self.flow_head(tokens) - - return flow, None - - -class CorrBlock: - def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): - """ - Build a pyramid of feature maps from the input. - - fmaps: Tensor (B, S, C, H, W) - num_levels: number of pyramid levels (each downsampled by factor 2) - radius: search radius for sampling correlation - multiple_track_feats: if True, split the target features per pyramid level - padding_mode: passed to grid_sample / bilinear_sampler - """ - B, S, C, H, W = fmaps.shape - self.S, self.C, self.H, self.W = S, C, H, W - self.num_levels = num_levels - self.radius = radius - self.padding_mode = padding_mode - self.multiple_track_feats = multiple_track_feats - - # Build pyramid: each level is half the spatial resolution of the previous - self.fmaps_pyramid = [fmaps] # level 0 is full resolution - current_fmaps = fmaps - for i in range(num_levels - 1): - B, S, C, H, W = current_fmaps.shape - # Merge batch & sequence dimensions - current_fmaps = current_fmaps.reshape(B * S, C, H, W) - # Avg pool down by factor 2 - current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) - _, _, H_new, W_new = current_fmaps.shape - current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) - self.fmaps_pyramid.append(current_fmaps) - - # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. - # This grid is added to the (scaled) coordinate centroids. - r = self.radius - dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) - dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) - # delta: for every (dy,dx) displacement (i.e. Δx, Δy) - self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) - - def corr_sample(self, targets, coords): - """ - Instead of storing the entire correlation pyramid, we compute each level's correlation - volume, sample it immediately, then discard it. This saves GPU memory. - - Args: - targets: Tensor (B, S, N, C) — features for the current targets. - coords: Tensor (B, S, N, 2) — coordinates at full resolution. - - Returns: - Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) - """ - B, S, N, C = targets.shape - - # If you have multiple track features, split them per level. - if self.multiple_track_feats: - targets_split = torch.split(targets, C // self.num_levels, dim=-1) - - out_pyramid = [] - for i, fmaps in enumerate(self.fmaps_pyramid): - # Get current spatial resolution H, W for this pyramid level. - B, S, C, H, W = fmaps.shape - # Reshape feature maps for correlation computation: - # fmap2s: (B, S, C, H*W) - fmap2s = fmaps.view(B, S, C, H * W) - # Choose appropriate target features. - fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) - - # Compute correlation directly - corrs = compute_corr_level(fmap1, fmap2s, C) - corrs = corrs.view(B, S, N, H, W) - - # Prepare sampling grid: - # Scale down the coordinates for the current level. - centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) - # Make sure our precomputed delta grid is on the same device/dtype. - delta_lvl = self.delta.to(coords.device).to(coords.dtype) - # Now the grid for grid_sample is: - # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) - coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) - - # Sample from the correlation volume using bilinear interpolation. 
- # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. - corrs_sampled = bilinear_sampler( - corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode - ) - # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. - corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) - out_pyramid.append(corrs_sampled) - - # Concatenate all levels along the last dimension. - out = torch.cat(out_pyramid, dim=-1).contiguous() - return out - - -def compute_corr_level(fmap1, fmap2s, C): - # fmap1: (B, S, N, C) - # fmap2s: (B, S, C, H*W) - corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) - corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) - return corrs / math.sqrt(C) diff --git a/capvector-pi05/src/vggt/heads/track_modules/modules.py b/capvector-pi05/src/vggt/heads/track_modules/modules.py deleted file mode 100644 index 84a9f64bda7d749f01b9b9243b13659461008355..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/track_modules/modules.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from typing import Callable -import collections -from torch import Tensor -from itertools import repeat - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - return tuple(x) - return tuple(repeat(x, n)) - - return parse - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -to_2tuple = _ntuple(2) - - -class ResidualBlock(nn.Module): - """ - ResidualBlock: construct a block of two conv layers with residual connections - """ - - def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): - super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" - ) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == "none": - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not stride == 1: - self.norm3 = nn.Sequential() - else: - raise NotImplementedError - - if stride == 1: - self.downsample = None - else: - self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - 
return self.relu(x + y) - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=None, - bias=True, - drop=0.0, - use_conv=False, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - bias = to_2tuple(bias) - drop_probs = to_2tuple(drop) - linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - - self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) - self.act = act_layer() - self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) - self.drop2 = nn.Dropout(drop_probs[1]) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop1(x) - x = self.fc2(x) - x = self.drop2(x) - return x - - -class AttnBlock(nn.Module): - def __init__( - self, - hidden_size, - num_heads, - attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, - mlp_ratio=4.0, - **block_kwargs, - ): - """ - Self attention block - """ - super().__init__() - - self.norm1 = nn.LayerNorm(hidden_size) - self.norm2 = nn.LayerNorm(hidden_size) - - self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) - - def forward(self, x, mask=None): - # Prepare the mask for PyTorch's attention (it expects a different format) - # attn_mask = mask if mask is not None else None - # Normalize before attention - x = self.norm1(x) - - # PyTorch's MultiheadAttention returns attn_output, attn_output_weights - # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) - - attn_output, _ = self.attn(x, x, x) - - # Add & Norm - x = x + attn_output - x = x + self.mlp(self.norm2(x)) - return x - - -class CrossAttnBlock(nn.Module): - def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): - """ - Cross attention block - """ - super().__init__() - - self.norm1 = nn.LayerNorm(hidden_size) - self.norm_context = nn.LayerNorm(hidden_size) - self.norm2 = nn.LayerNorm(hidden_size) - - self.cross_attn = nn.MultiheadAttention( - embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs - ) - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) - - def forward(self, x, context, mask=None): - # Normalize inputs - x = self.norm1(x) - context = self.norm_context(context) - - # Apply cross attention - # Note: nn.MultiheadAttention returns attn_output, attn_output_weights - attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) - - # Add & Norm - x = x + attn_output - x = x + self.mlp(self.norm2(x)) - return x diff --git a/capvector-pi05/src/vggt/heads/track_modules/utils.py b/capvector-pi05/src/vggt/heads/track_modules/utils.py deleted file mode 100644 index 3fc9486f5d070c882273de1165e6f1322b6c5ce4..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/track_modules/utils.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
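The correlation lookup at the core of `CorrBlock` / `compute_corr_level` in `blocks.py` above reduces to a scaled dot product between each track's feature vector and every location of a feature map. A standalone sketch with assumed toy shapes (not test code from the original repository):

```python
import math

import torch

B, S, N, C, H, W = 1, 2, 3, 16, 8, 8   # toy sizes, assumed for illustration
track_feats = torch.randn(B, S, N, C)  # one feature vector per track and frame
fmaps = torch.randn(B, S, C, H, W)     # per-frame feature maps

# compute_corr_level(): correlate each track feature with every spatial location,
# scaled by 1/sqrt(C), exactly as in blocks.py above.
fmap2s = fmaps.view(B, S, C, H * W)
corrs = torch.matmul(track_feats, fmap2s) / math.sqrt(C)  # (B, S, N, H*W)
corrs = corrs.view(B, S, N, H, W)

# CorrBlock.corr_sample() then bilinearly samples a (2r+1) x (2r+1) window of this
# volume around each track's current coordinate estimate (per pyramid level) and
# feeds the flattened windows to the tracker's correlation MLP.
print(corrs.shape)  # torch.Size([1, 2, 3, 8, 8])
```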
- -# Modified from https://github.com/facebookresearch/vggsfm -# and https://github.com/facebookresearch/co-tracker/tree/main - - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from typing import Optional, Tuple, Union - - -def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: - """ - This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. - It is a wrapper of get_2d_sincos_pos_embed_from_grid. - Args: - - embed_dim: The embedding dimension. - - grid_size: The grid size. - Returns: - - pos_embed: The generated 2D positional embedding. - """ - if isinstance(grid_size, tuple): - grid_size_h, grid_size_w = grid_size - else: - grid_size_h = grid_size_w = grid_size - grid_h = torch.arange(grid_size_h, dtype=torch.float) - grid_w = torch.arange(grid_size_w, dtype=torch.float) - grid = torch.meshgrid(grid_w, grid_h, indexing="xy") - grid = torch.stack(grid, dim=0) - grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if return_grid: - return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) - return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) - - -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: - """ - This function generates a 2D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - grid: The grid to generate the embedding from. - - Returns: - - emb: The generated 2D positional embedding. - """ - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: - """ - This function generates a 1D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - pos: The position to generate the embedding from. - - Returns: - - emb: The generated 1D positional embedding. - """ - assert embed_dim % 2 == 0 - omega = torch.arange(embed_dim // 2, dtype=torch.double) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = torch.sin(out) # (M, D/2) - emb_cos = torch.cos(out) # (M, D/2) - - emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) - return emb[None].float() - - -def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: - """ - This function generates a 2D positional embedding from given coordinates using sine and cosine functions. - - Args: - - xy: The coordinates to generate the embedding from. - - C: The size of the embedding. - - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. - - Returns: - - pe: The generated 2D positional embedding. 
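    Note (editorial, not in the original docstring): for input of shape (B, N, 2) the
    returned embedding has shape (B, N, 2*C) when cat_coords is False and
    (B, N, 2*C + 2) when cat_coords is True.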
- """ - B, N, D = xy.shape - assert D == 2 - - x = xy[:, :, 0:1] - y = xy[:, :, 1:2] - div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) - - pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) - pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) - - pe_x[:, :, 0::2] = torch.sin(x * div_term) - pe_x[:, :, 1::2] = torch.cos(x * div_term) - - pe_y[:, :, 0::2] = torch.sin(y * div_term) - pe_y[:, :, 1::2] = torch.cos(y * div_term) - - pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) - if cat_coords: - pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) - return pe - - -def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): - r"""Sample a tensor using bilinear interpolation - - `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at - coordinates :attr:`coords` using bilinear interpolation. It is the same - as `torch.nn.functional.grid_sample()` but with a different coordinate - convention. - - The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where - :math:`B` is the batch size, :math:`C` is the number of channels, - :math:`H` is the height of the image, and :math:`W` is the width of the - image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is - interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. - - Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, - in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note - that in this case the order of the components is slightly different - from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. - - If `align_corners` is `True`, the coordinate :math:`x` is assumed to be - in the range :math:`[0,W-1]`, with 0 corresponding to the center of the - left-most image pixel :math:`W-1` to the center of the right-most - pixel. - - If `align_corners` is `False`, the coordinate :math:`x` is assumed to - be in the range :math:`[0,W]`, with 0 corresponding to the left edge of - the left-most pixel :math:`W` to the right edge of the right-most - pixel. - - Similar conventions apply to the :math:`y` for the range - :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range - :math:`[0,T-1]` and :math:`[0,T]`. - - Args: - input (Tensor): batch of input images. - coords (Tensor): batch of coordinates. - align_corners (bool, optional): Coordinate convention. Defaults to `True`. - padding_mode (str, optional): Padding mode. Defaults to `"border"`. - - Returns: - Tensor: sampled points. 
- """ - coords = coords.detach().clone() - ############################################################ - # IMPORTANT: - coords = coords.to(input.device).to(input.dtype) - ############################################################ - - sizes = input.shape[2:] - - assert len(sizes) in [2, 3] - - if len(sizes) == 3: - # t x y -> x y t to match dimensions T H W in grid_sample - coords = coords[..., [1, 2, 0]] - - if align_corners: - scale = torch.tensor( - [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype - ) - else: - scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) - - coords.mul_(scale) # coords = coords * scale - coords.sub_(1) # coords = coords - 1 - - return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) - - -def sample_features4d(input, coords): - r"""Sample spatial features - - `sample_features4d(input, coords)` samples the spatial features - :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. - - The field is sampled at coordinates :attr:`coords` using bilinear - interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, - 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the - same convention as :func:`bilinear_sampler` with `align_corners=True`. - - The output tensor has one feature per point, and has shape :math:`(B, - R, C)`. - - Args: - input (Tensor): spatial features. - coords (Tensor): points. - - Returns: - Tensor: sampled features. - """ - - B, _, _, _ = input.shape - - # B R 2 -> B R 1 2 - coords = coords.unsqueeze(2) - - # B C R 1 - feats = bilinear_sampler(input, coords) - - return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/capvector-pi05/src/vggt/heads/utils.py b/capvector-pi05/src/vggt/heads/utils.py deleted file mode 100644 index 1804227cff9d9fde67712bbb7e5d64b4be88d6cf..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/heads/utils.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -import numpy as np -from typing import List, Dict, Tuple, Union -from einops import rearrange - -def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: - """ - Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) - - Args: - pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates - embed_dim: Output channel dimension for embeddings - - Returns: - Tensor of shape (H, W, embed_dim) with positional embeddings - """ - H, W, grid_dim = pos_grid.shape - assert grid_dim == 2 - pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) - - # Process x and y coordinates separately - emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] - emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] - - # Combine and reshape - emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] - - return emb.view(H, W, embed_dim) # [H, W, D] - - -def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: - """ - This function generates a 1D positional embedding from a given grid using sine and cosine functions. 
- - Args: - - embed_dim: The embedding dimension. - - pos: The position to generate the embedding from. - - Returns: - - emb: The generated 1D positional embedding. - """ - assert embed_dim % 2 == 0 - device = pos.device - omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) - omega /= embed_dim / 2.0 - omega = 1.0 / omega_0**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = torch.sin(out) # (M, D/2) - emb_cos = torch.cos(out) # (M, D/2) - - emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) - return emb.float() - - -# Inspired by https://github.com/microsoft/moge - - -def create_uv_grid( - width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None -) -> torch.Tensor: - """ - Create a normalized UV grid of shape (width, height, 2). - - The grid spans horizontally and vertically according to an aspect ratio, - ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right - corner is at (x_span, y_span), normalized by the diagonal of the plane. - - Args: - width (int): Number of points horizontally. - height (int): Number of points vertically. - aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. - dtype (torch.dtype, optional): Data type of the resulting tensor. - device (torch.device, optional): Device on which the tensor is created. - - Returns: - torch.Tensor: A (width, height, 2) tensor of UV coordinates. - """ - # Derive aspect ratio if not explicitly provided - if aspect_ratio is None: - aspect_ratio = float(width) / float(height) - - # Compute normalized spans for X and Y - diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 - span_x = aspect_ratio / diag_factor - span_y = 1.0 / diag_factor - - # Establish the linspace boundaries - left_x = -span_x * (width - 1) / width - right_x = span_x * (width - 1) / width - top_y = -span_y * (height - 1) / height - bottom_y = span_y * (height - 1) / height - - # Generate 1D coordinates - x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) - y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) - - # Create 2D meshgrid (width x height) and stack into UV - uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") - uv_grid = torch.stack((uu, vv), dim=-1) - - return uv_grid - - -def _interpolate( - x: torch.Tensor, - size: Tuple[int, int] = None, - scale_factor: float = None, - mode: str = "bilinear", - align_corners: bool = True, -) -> torch.Tensor: - """ - Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. - """ - if size is None: - size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) - - INT_MAX = 1610612736 - - input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] - - if input_elements > INT_MAX: - chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) - interpolated_chunks = [ - nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks - ] - x = torch.cat(interpolated_chunks, dim=0) - return x.contiguous() - else: - return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) - - -def _apply_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: - """ - Apply positional embedding to tensor x. 
- """ - patch_w = x.shape[-1] - patch_h = x.shape[-2] - pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) - pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) - pos_embed = pos_embed * ratio - pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) - return x + pos_embed - - -def interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe): - (patch_h, patch_w) = patch_hw - (img_h, img_w) = img_hw - bs, N, S, D = hidden.shape - re_sample_ratio = 1 / np.sqrt(N * S / reference.shape[1]) - - _hidden = hidden.permute(0, 1, 3, 2) - _hidden = _hidden.reshape(bs*N, D, patch_h, patch_w) - if use_vggt_pe: - _hidden = _apply_pos_embed(_hidden, img_w, img_h) - hidden_pooling = _interpolate( - _hidden, scale_factor=re_sample_ratio, mode=pooling_func, align_corners=True - ) - hidden_pooling = hidden_pooling.reshape(bs, N, D, -1).permute(0, 1, 3, 2).reshape(bs, -1, D) - return hidden_pooling - - -def custom_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe): - if pooling_func in ['bilinear']: - return interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe) - else: - raise NotImplementedError(f"Pooling function {pooling_func} is not implemented.") diff --git a/capvector-pi05/src/vggt/layers/__init__.py b/capvector-pi05/src/vggt/layers/__init__.py deleted file mode 100644 index e59a83eb90512d763b03e4d38536b6ae07e87541..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from .mlp import Mlp -from .patch_embed import PatchEmbed -from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused -from .block import NestedTensorBlock -from .attention import MemEffAttention diff --git a/capvector-pi05/src/vggt/layers/attention.py b/capvector-pi05/src/vggt/layers/attention.py deleted file mode 100644 index 27329716a95a1c3e70a12e74b3be5fe79f2663f9..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/attention.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -import logging -import os -import warnings - -from torch import Tensor -from torch import nn -import torch.nn.functional as F - -XFORMERS_AVAILABLE = False - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - proj_bias: bool = True, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - norm_layer: nn.Module = nn.LayerNorm, - qk_norm: bool = False, - fused_attn: bool = True, # use F.scaled_dot_product_attention or not - rope=None, - ) -> None: - super().__init__() - assert dim % num_heads == 0, "dim should be divisible by num_heads" - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim**-0.5 - self.fused_attn = fused_attn - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - self.rope = rope - - def forward(self, x: Tensor, pos=None) -> Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - q, k = self.q_norm(q), self.k_norm(k) - - if self.rope is not None: - q = self.rope(q, pos) - k = self.rope(k, pos) - - if self.fused_attn: - x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) - else: - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class MemEffAttention(Attention): - def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: - assert pos is None - if not XFORMERS_AVAILABLE: - if attn_bias is not None: - raise AssertionError("xFormers is required for using nested tensors") - return super().forward(x) - - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - - q, k, v = unbind(qkv, 2) - - x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) - x = x.reshape([B, N, C]) - - x = self.proj(x) - x = self.proj_drop(x) - return x diff --git a/capvector-pi05/src/vggt/layers/block.py b/capvector-pi05/src/vggt/layers/block.py deleted file mode 100644 index 15fa99ce76d14b6b9c2e98c1031fa35a3046a429..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/block.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
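For reference, the fused and manual branches of `Attention.forward` above compute the same attention output; a small sanity check (illustrative shapes, PyTorch >= 2.0 for `scaled_dot_product_attention`) is:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, Dh = 2, 8, 16, 64                       # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, N, Dh) for _ in range(3))

# Fused path (fused_attn=True in Attention.forward).
fused = F.scaled_dot_product_attention(q, k, v)

# Manual path (fused_attn=False): scale queries, softmax over keys, weight values.
scale = Dh ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)
manual = attn.softmax(dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True, up to numerical tolerance
```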
- -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py - -import logging -import os -from typing import Callable, List, Any, Tuple, Dict -import warnings - -import torch -from torch import nn, Tensor - -from .attention import Attention -from .drop_path import DropPath -from .layer_scale import LayerScale -from .mlp import Mlp - - -XFORMERS_AVAILABLE = False - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - proj_bias: bool = True, - ffn_bias: bool = True, - drop: float = 0.0, - attn_drop: float = 0.0, - init_values=None, - drop_path: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_class: Callable[..., nn.Module] = Attention, - ffn_layer: Callable[..., nn.Module] = Mlp, - qk_norm: bool = False, - fused_attn: bool = True, # use F.scaled_dot_product_attention or not - rope=None, - ) -> None: - super().__init__() - - self.norm1 = norm_layer(dim) - - self.attn = attn_class( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - attn_drop=attn_drop, - proj_drop=drop, - qk_norm=qk_norm, - fused_attn=fused_attn, - rope=rope, - ) - - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = ffn_layer( - in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias - ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.sample_drop_ratio = drop_path - - def forward(self, x: Tensor, pos=None) -> Tensor: - def attn_residual_func(x: Tensor, pos=None) -> Tensor: - return self.ls1(self.attn(self.norm1(x), pos=pos)) - - def ffn_residual_func(x: Tensor) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - if self.training and self.sample_drop_ratio > 0.1: - # the overhead is compensated only for a drop path rate larger than 0.1 - x = drop_add_residual_stochastic_depth( - x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio - ) - x = drop_add_residual_stochastic_depth( - x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio - ) - elif self.training and self.sample_drop_ratio > 0.0: - x = x + self.drop_path1(attn_residual_func(x, pos=pos)) - x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 - else: - x = x + attn_residual_func(x, pos=pos) - x = x + ffn_residual_func(x) - return x - - -def drop_add_residual_stochastic_depth( - x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None -) -> Tensor: - # 1) extract subset using permutation - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - x_subset = x[brange] - - # 2) apply residual_func to get residual - if pos is not None: - # if necessary, apply rope to the subset - pos = pos[brange] - residual = residual_func(x_subset, pos=pos) - else: - residual = residual_func(x_subset) - - x_flat = x.flatten(1) - residual = residual.flatten(1) - - residual_scale_factor = b / sample_subset_size - - # 3) add 
the residual - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - return x_plus_residual.view_as(x) - - -def get_branges_scales(x, sample_drop_ratio=0.0): - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - residual_scale_factor = b / sample_subset_size - return brange, residual_scale_factor - - -def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): - if scaling_vector is None: - x_flat = x.flatten(1) - residual = residual.flatten(1) - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - else: - x_plus_residual = scaled_index_add( - x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor - ) - return x_plus_residual - - -attn_bias_cache: Dict[Tuple, Any] = {} - - -def get_attn_bias_and_cat(x_list, branges=None): - """ - this will perform the index select, cat the tensors, and provide the attn_bias from cache - """ - batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] - all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) - if all_shapes not in attn_bias_cache.keys(): - seqlens = [] - for b, x in zip(batch_sizes, x_list): - for _ in range(b): - seqlens.append(x.shape[1]) - attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) - attn_bias._batch_sizes = batch_sizes - attn_bias_cache[all_shapes] = attn_bias - - if branges is not None: - cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) - else: - tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) - cat_tensors = torch.cat(tensors_bs1, dim=1) - - return attn_bias_cache[all_shapes], cat_tensors - - -def drop_add_residual_stochastic_depth_list( - x_list: List[Tensor], - residual_func: Callable[[Tensor, Any], Tensor], - sample_drop_ratio: float = 0.0, - scaling_vector=None, -) -> Tensor: - # 1) generate random set of indices for dropping samples in the batch - branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] - branges = [s[0] for s in branges_scales] - residual_scale_factors = [s[1] for s in branges_scales] - - # 2) get attention bias and index+concat the tensors - attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) - - # 3) apply residual_func to get residual, and split the result - residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore - - outputs = [] - for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): - outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) - return outputs - - -class NestedTensorBlock(Block): - def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: - """ - x_list contains a list of tensors to nest together and run - """ - assert isinstance(self.attn, MemEffAttention) - - if self.training and self.sample_drop_ratio > 0.0: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.attn(self.norm1(x), attn_bias=attn_bias) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.mlp(self.norm2(x)) - - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - 
scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None), - ) - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None), - ) - return x_list - else: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - attn_bias, x = get_attn_bias_and_cat(x_list) - x = x + attn_residual_func(x, attn_bias=attn_bias) - x = x + ffn_residual_func(x) - return attn_bias.split(x) - - def forward(self, x_or_x_list): - if isinstance(x_or_x_list, Tensor): - return super().forward(x_or_x_list) - elif isinstance(x_or_x_list, list): - if not XFORMERS_AVAILABLE: - raise AssertionError("xFormers is required for using nested tensors") - return self.forward_nested(x_or_x_list) - else: - raise AssertionError diff --git a/capvector-pi05/src/vggt/layers/drop_path.py b/capvector-pi05/src/vggt/layers/drop_path.py deleted file mode 100644 index 4bb1487b0eed4cb14dc0d5d1ee57a2acc78de34a..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/drop_path.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py - - -from torch import nn - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0: - random_tensor.div_(keep_prob) - output = x * random_tensor - return output - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) diff --git a/capvector-pi05/src/vggt/layers/layer_scale.py b/capvector-pi05/src/vggt/layers/layer_scale.py deleted file mode 100644 index 9047736a9fcd57a091aac8d42a8c07cc348cd1b3..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/layer_scale.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
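The `drop_path.py` module deleted above implements stochastic depth: during training each sample's residual branch is dropped with probability `drop_prob`, and the survivors are rescaled so the expected output is unchanged. A minimal sketch of the same idea, with hypothetical names:

```python
import torch

def drop_path_demo(x: torch.Tensor, drop_prob: float = 0.25) -> torch.Tensor:
    # One Bernoulli draw per sample in the batch, broadcast over the remaining dims.
    keep_prob = 1.0 - drop_prob
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
    # Rescale surviving samples by 1/keep_prob so the expected value matches the input.
    return x * mask / keep_prob

torch.manual_seed(0)
x = torch.ones(8, 4, 16)
y = drop_path_demo(x)
print(y[:, 0, 0].tolist())  # roughly 25% zeros, the rest scaled to 1/0.75
```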
- -# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 - -from typing import Union - -import torch -from torch import Tensor -from torch import nn - - -class LayerScale(nn.Module): - def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/capvector-pi05/src/vggt/layers/mlp.py b/capvector-pi05/src/vggt/layers/mlp.py deleted file mode 100644 index 0965768a9aef04ac6b81322f4dd60cf035159e91..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/mlp.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py - - -from typing import Callable, Optional - -from torch import Tensor, nn - - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.drop = nn.Dropout(drop) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x diff --git a/capvector-pi05/src/vggt/layers/patch_embed.py b/capvector-pi05/src/vggt/layers/patch_embed.py deleted file mode 100644 index 7244ad8e3b956417f52b4bcea1aefb3796fc7e59..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/patch_embed.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py - -from typing import Callable, Optional, Tuple, Union - -from torch import Tensor -import torch.nn as nn - - -def make_2tuple(x): - if isinstance(x, tuple): - assert len(x) == 2 - return x - - assert isinstance(x, int) - return (x, x) - - -class PatchEmbed(nn.Module): - """ - 2D image to patch embedding: (B,C,H,W) -> (B,N,D) - - Args: - img_size: Image size. - patch_size: Patch token size. - in_chans: Number of input image channels. - embed_dim: Number of linear projection output channels. - norm_layer: Normalization layer. 
- """ - - def __init__( - self, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten_embedding: bool = True, - ) -> None: - super().__init__() - - image_HW = make_2tuple(img_size) - patch_HW = make_2tuple(patch_size) - patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) - - self.img_size = image_HW - self.patch_size = patch_HW - self.patches_resolution = patch_grid_size - self.num_patches = patch_grid_size[0] * patch_grid_size[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.flatten_embedding = flatten_embedding - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - _, _, H, W = x.shape - patch_H, patch_W = self.patch_size - - assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" - assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" - - x = self.proj(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) # B HW C - x = self.norm(x) - if not self.flatten_embedding: - x = x.reshape(-1, H, W, self.embed_dim) # B H W C - return x - - def flops(self) -> float: - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops diff --git a/capvector-pi05/src/vggt/layers/rope.py b/capvector-pi05/src/vggt/layers/rope.py deleted file mode 100644 index 107ff7a2267e936bc01c4dbd576e1da4f038f904..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/rope.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - - -# Implementation of 2D Rotary Position Embeddings (RoPE). - -# This module provides a clean implementation of 2D Rotary Position Embeddings, -# which extends the original RoPE concept to handle 2D spatial positions. - -# Inspired by: -# https://github.com/meta-llama/codellama/blob/main/llama/model.py -# https://github.com/naver-ai/rope-vit - - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Dict, Tuple - - -class PositionGetter: - """Generates and caches 2D spatial positions for patches in a grid. - - This class efficiently manages the generation of spatial coordinates for patches - in a 2D grid, caching results to avoid redundant computations. - - Attributes: - position_cache: Dictionary storing precomputed position tensors for different - grid dimensions. - """ - - def __init__(self): - """Initializes the position generator with an empty cache.""" - self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} - - def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: - """Generates spatial positions for a batch of patches. - - Args: - batch_size: Number of samples in the batch. - height: Height of the grid in patches. - width: Width of the grid in patches. - device: Target device for the position tensor. 
- - Returns: - Tensor of shape (batch_size, height*width, 2) containing y,x coordinates - for each position in the grid, repeated for each batch item. - """ - if (height, width) not in self.position_cache: - y_coords = torch.arange(height, device=device) - x_coords = torch.arange(width, device=device) - positions = torch.cartesian_prod(y_coords, x_coords) - self.position_cache[height, width] = positions - - cached_positions = self.position_cache[height, width] - return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() - - -class RotaryPositionEmbedding2D(nn.Module): - """2D Rotary Position Embedding implementation. - - This module applies rotary position embeddings to input tokens based on their - 2D spatial positions. It handles the position-dependent rotation of features - separately for vertical and horizontal dimensions. - - Args: - frequency: Base frequency for the position embeddings. Default: 100.0 - scaling_factor: Scaling factor for frequency computation. Default: 1.0 - - Attributes: - base_frequency: Base frequency for computing position embeddings. - scaling_factor: Factor to scale the computed frequencies. - frequency_cache: Cache for storing precomputed frequency components. - """ - - def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): - """Initializes the 2D RoPE module.""" - super().__init__() - self.base_frequency = frequency - self.scaling_factor = scaling_factor - self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} - - def _compute_frequency_components( - self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes frequency components for rotary embeddings. - - Args: - dim: Feature dimension (must be even). - seq_len: Maximum sequence length. - device: Target device for computations. - dtype: Data type for the computed tensors. - - Returns: - Tuple of (cosine, sine) tensors for frequency components. - """ - cache_key = (dim, seq_len, device, dtype) - if cache_key not in self.frequency_cache: - # Compute frequency bands - exponents = torch.arange(0, dim, 2, device=device).float() / dim - inv_freq = 1.0 / (self.base_frequency**exponents) - - # Generate position-dependent frequencies - positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - angles = torch.einsum("i,j->ij", positions, inv_freq) - - # Compute and cache frequency components - angles = angles.to(dtype) - angles = torch.cat((angles, angles), dim=-1) - cos_components = angles.cos().to(dtype) - sin_components = angles.sin().to(dtype) - self.frequency_cache[cache_key] = (cos_components, sin_components) - - return self.frequency_cache[cache_key] - - @staticmethod - def _rotate_features(x: torch.Tensor) -> torch.Tensor: - """Performs feature rotation by splitting and recombining feature dimensions. - - Args: - x: Input tensor to rotate. - - Returns: - Rotated feature tensor. - """ - feature_dim = x.shape[-1] - x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_1d_rope( - self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor - ) -> torch.Tensor: - """Applies 1D rotary position embeddings along one dimension. - - Args: - tokens: Input token features. - positions: Position indices. - cos_comp: Cosine components for rotation. - sin_comp: Sine components for rotation. - - Returns: - Tokens with applied rotary position embeddings. 
- """ - # Embed positions with frequency components - cos = F.embedding(positions, cos_comp)[:, None, :, :] - sin = F.embedding(positions, sin_comp)[:, None, :, :] - - # Apply rotation - return (tokens * cos) + (self._rotate_features(tokens) * sin) - - def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - """Applies 2D rotary position embeddings to input tokens. - - Args: - tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). - The feature dimension (dim) must be divisible by 4. - positions: Position tensor of shape (batch_size, n_tokens, 2) containing - the y and x coordinates for each token. - - Returns: - Tensor of same shape as input with applied 2D rotary position embeddings. - - Raises: - AssertionError: If input dimensions are invalid or positions are malformed. - """ - # Validate inputs - assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" - assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" - - # Compute feature dimension for each spatial direction - feature_dim = tokens.size(-1) // 2 - - # Get frequency components - max_position = int(positions.max()) + 1 - cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) - - # Split features for vertical and horizontal processing - vertical_features, horizontal_features = tokens.chunk(2, dim=-1) - - # Apply RoPE separately for each dimension - vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) - horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) - - # Combine processed features - return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/capvector-pi05/src/vggt/layers/swiglu_ffn.py b/capvector-pi05/src/vggt/layers/swiglu_ffn.py deleted file mode 100644 index 9c6b6c74f97e61041ecef912ea21c2d259335aa7..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/swiglu_ffn.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -import os -from typing import Callable, Optional -import warnings - -from torch import Tensor, nn -import torch.nn.functional as F - - -class SwiGLUFFN(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) - self.w3 = nn.Linear(hidden_features, out_features, bias=bias) - - def forward(self, x: Tensor) -> Tensor: - x12 = self.w12(x) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) - - -XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None -# try: -# if XFORMERS_ENABLED: -# from xformers.ops import SwiGLU - -# XFORMERS_AVAILABLE = True -# warnings.warn("xFormers is available (SwiGLU)") -# else: -# warnings.warn("xFormers is disabled (SwiGLU)") -# raise ImportError -# except ImportError: -SwiGLU = SwiGLUFFN -XFORMERS_AVAILABLE = False - -# warnings.warn("xFormers is not available (SwiGLU)") - - -class SwiGLUFFNFused(SwiGLU): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - out_features = out_features or in_features - hidden_features = hidden_features or in_features - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) diff --git a/capvector-pi05/src/vggt/layers/vision_transformer.py b/capvector-pi05/src/vggt/layers/vision_transformer.py deleted file mode 100644 index ced58dd042a84b44ca97ce3f25d3983f322a8e27..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/layers/vision_transformer.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -from functools import partial -import math -import logging -from typing import Sequence, Tuple, Union, Callable - -import torch -import torch.nn as nn -from torch.utils.checkpoint import checkpoint -from torch.nn.init import trunc_normal_ -from . 
import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block - -logger = logging.getLogger("dinov2") - - -def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = ".".join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -class BlockChunk(nn.ModuleList): - def forward(self, x): - for b in self: - x = b(x) - return x - - -class DinoVisionTransformer(nn.Module): - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - init_values=None, # for layerscale: None or 0 => no layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=Block, - ffn_layer="mlp", - block_chunks=1, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - qk_norm=False, - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - proj_bias (bool): enable bias for proj in attn if True - ffn_bias (bool): enable bias for ffn if True - drop_path_rate (float): stochastic depth rate - drop_path_uniform (bool): apply uniform drop rate across blocks - weight_init (str): weight init scheme - init_values (float): layer-scale init values - embed_layer (nn.Module): patch embedding layer - act_layer (nn.Module): MLP activation layer - block_fn (nn.Module): transformer block class - ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" - block_chunks: (int) split block sequence into block_chunks units for FSDP wrap - num_register_tokens: (int) number of extra cls tokens (so-called "registers") - interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings - interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings - """ - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - self.num_register_tokens = num_register_tokens - self.interpolate_antialias = interpolate_antialias - self.interpolate_offset = interpolate_offset - self.use_reentrant = False # hardcoded to False - - self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - assert num_register_tokens >= 0 - self.register_tokens = ( - nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None - ) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - - if ffn_layer == "mlp": - logger.info("using MLP layer as FFN") - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - logger.info("using SwiGLU layer as FFN") - ffn_layer = SwiGLUFFNFused - elif ffn_layer == "identity": - logger.info("using Identity layer as FFN") - - def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - ffn_layer=ffn_layer, - init_values=init_values, - qk_norm=qk_norm, - ) - for i in range(depth) - ] - if block_chunks > 0: - self.chunked_blocks = True - chunked_blocks = [] - chunksize = depth // block_chunks - for i in range(0, depth, chunksize): - # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) - self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) - else: - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - - self.norm = norm_layer(embed_dim) - self.head = nn.Identity() - - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - - self.init_weights() - - def init_weights(self): - trunc_normal_(self.pos_embed, std=0.02) - nn.init.normal_(self.cls_token, std=1e-6) - if self.register_tokens is not None: - nn.init.normal_(self.register_tokens, std=1e-6) - named_apply(init_weights_vit_timm, self) - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed.float() - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - M = int(math.sqrt(N)) # Recover the number of patches in each dimension - assert N == M * M - kwargs = {} - if self.interpolate_offset: - # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 - # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors - sx = float(w0 + self.interpolate_offset) / M - sy = float(h0 + self.interpolate_offset) / M - kwargs["scale_factor"] = (sx, sy) - else: - # Simply specify an output size instead of a scale factor - kwargs["size"] = (w0, h0) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), - mode="bicubic", - antialias=self.interpolate_antialias, - **kwargs, - ) - assert (w0, h0) == patch_pos_embed.shape[-2:] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def prepare_tokens_with_masks(self, x, masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) - - if self.register_tokens is not None: - x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1) - - 
return x - - def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] - - for blk in self.blocks: - if self.training: - x = checkpoint(blk, x, use_reentrant=self.use_reentrant) - else: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append( - { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - ) - return output - - def forward_features(self, x, masks=None): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - x = self.prepare_tokens_with_masks(x, masks) - - for blk in self.blocks: - if self.training: - x = checkpoint(blk, x, use_reentrant=self.use_reentrant) - else: - x = blk(x) - - x_norm = self.norm(x) - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # If n is an int, take the n last blocks. If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - output, i, total_block_len = [], 0, len(self.blocks[-1]) - # If n is an int, take the n last blocks. 
If it's a list, take them - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for block_chunk in self.blocks: - for blk in block_chunk[i:]: # Passing the nn.Identity() - x = blk(x) - if i in blocks_to_take: - output.append(x) - i += 1 - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - if self.chunked_blocks: - outputs = self._get_intermediate_layers_chunked(x, n) - else: - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() - for out in outputs - ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=True, **kwargs): - ret = self.forward_features(*args, **kwargs) - if is_training: - return ret - else: - return self.head(ret["x_norm_clstoken"]) - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - - -def vit_small(patch_size=16, num_register_tokens=0, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_base(patch_size=16, num_register_tokens=0, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_large(patch_size=16, num_register_tokens=0, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): - """ - Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 - """ - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1536, - depth=40, - num_heads=24, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model diff --git a/capvector-pi05/src/vggt/models/aggregator.py b/capvector-pi05/src/vggt/models/aggregator.py deleted file mode 100644 index 3ccc16110c82008a7652d05834a478226780a9b5..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/models/aggregator.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
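The factory functions above (`vit_small` through `vit_giant2`) build the DINOv2-style backbones that the aggregator below instantiates via its `patch_embed` option. A small usage sketch, again assuming the `vggt` package from this tree is importable and using illustrative sizes:

```python
import torch
from vggt.layers.vision_transformer import vit_small

# DINOv2 ViT-S/14 with 4 register tokens, matching the "dinov2_vits14_reg" option.
model = vit_small(patch_size=14, num_register_tokens=4, img_size=518, block_chunks=0)
model.eval()

x = torch.randn(1, 3, 224, 224)          # spatial size must be a multiple of patch_size
with torch.no_grad():
    out = model.forward_features(x)

print(out["x_norm_clstoken"].shape)       # torch.Size([1, 384])
print(out["x_norm_patchtokens"].shape)    # torch.Size([1, 256, 384]) -> 16x16 patches
```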
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.checkpoint import checkpoint -from typing import Optional, Tuple, Union, List, Dict, Any - -from vggt.layers import PatchEmbed -from vggt.layers.block import Block -from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter -from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 - -logger = logging.getLogger(__name__) - -_RESNET_MEAN = [0.485, 0.456, 0.406] -_RESNET_STD = [0.229, 0.224, 0.225] - - -class Aggregator(nn.Module): - """ - The Aggregator applies alternating-attention over input frames, - as described in VGGT: Visual Geometry Grounded Transformer. - - Remember to set model.train() to enable gradient checkpointing to reduce memory usage. - - Args: - img_size (int): Image size in pixels. - patch_size (int): Size of each patch for PatchEmbed. - embed_dim (int): Dimension of the token embeddings. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. - num_register_tokens (int): Number of register tokens. - block_fn (nn.Module): The block type used for attention (Block by default). - qkv_bias (bool): Whether to include bias in QKV projections. - proj_bias (bool): Whether to include bias in the output projection. - ffn_bias (bool): Whether to include bias in MLP layers. - patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". - aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. - aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. - qk_norm (bool): Whether to apply QK normalization. - rope_freq (int): Base frequency for rotary embedding. -1 to disable. - init_values (float): Init scale for layer scale. 
- """ - - def __init__( - self, - img_size=518, - patch_size=14, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4.0, - num_register_tokens=4, - block_fn=Block, - qkv_bias=True, - proj_bias=True, - ffn_bias=True, - patch_embed="dinov2_vitl14_reg", - aa_order=["frame", "global"], - aa_block_size=1, - qk_norm=True, - rope_freq=100, - init_values=0.01, - ): - super().__init__() - - self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) - - # Initialize rotary position embedding if frequency > 0 - self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None - self.position_getter = PositionGetter() if self.rope is not None else None - - self.frame_blocks = nn.ModuleList( - [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - init_values=init_values, - qk_norm=qk_norm, - rope=self.rope, - ) - for _ in range(depth) - ] - ) - - self.global_blocks = nn.ModuleList( - [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - init_values=init_values, - qk_norm=qk_norm, - rope=self.rope, - ) - for _ in range(depth) - ] - ) - - self.depth = depth - self.aa_order = aa_order - self.patch_size = patch_size - self.aa_block_size = aa_block_size - - # Validate that depth is divisible by aa_block_size - if self.depth % self.aa_block_size != 0: - raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") - - self.aa_block_num = self.depth // self.aa_block_size - - # Note: We have two camera tokens, one for the first frame and one for the rest - # The same applies for register tokens - self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) - self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) - - # The patch tokens start after the camera and register tokens - self.patch_start_idx = 1 + num_register_tokens - - # Initialize parameters with small values - nn.init.normal_(self.camera_token, std=1e-6) - nn.init.normal_(self.register_token, std=1e-6) - - # Register normalization constants as buffers - for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): - self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False) - - self.use_reentrant = False # hardcoded to False - - def __build_patch_embed__( - self, - patch_embed, - img_size, - patch_size, - num_register_tokens, - interpolate_antialias=True, - interpolate_offset=0.0, - block_chunks=0, - init_values=1.0, - embed_dim=1024, - ): - """ - Build the patch embed layer. If 'conv', we use a - simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
- """ - - if "conv" in patch_embed: - self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) - else: - vit_models = { - "dinov2_vitl14_reg": vit_large, - "dinov2_vitb14_reg": vit_base, - "dinov2_vits14_reg": vit_small, - "dinov2_vitg2_reg": vit_giant2, - } - - self.patch_embed = vit_models[patch_embed]( - img_size=img_size, - patch_size=patch_size, - num_register_tokens=num_register_tokens, - interpolate_antialias=interpolate_antialias, - interpolate_offset=interpolate_offset, - block_chunks=block_chunks, - init_values=init_values, - ) - - # Disable gradient updates for mask token - if hasattr(self.patch_embed, "mask_token"): - self.patch_embed.mask_token.requires_grad_(False) - - def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: - """ - Args: - images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. - B: batch size, S: sequence length, 3: RGB channels, H: height, W: width - - Returns: - (list[torch.Tensor], int): - The list of outputs from the attention blocks, - and the patch_start_idx indicating where patch tokens begin. - """ - B, S, C_in, H, W = images.shape - - if C_in != 3: - raise ValueError(f"Expected 3 input channels, got {C_in}") - - # Normalize images and reshape for patch embed - images = (images - self._resnet_mean) / self._resnet_std - - # Reshape to [B*S, C, H, W] for patch embedding - images = images.view(B * S, C_in, H, W) - patch_tokens = self.patch_embed(images) - - if isinstance(patch_tokens, dict): - patch_tokens = patch_tokens["x_norm_patchtokens"] - - _, P, C = patch_tokens.shape - - # Expand camera and register tokens to match batch size and sequence length - camera_token = slice_expand_and_flatten(self.camera_token, B, S) - register_token = slice_expand_and_flatten(self.register_token, B, S) - - # Concatenate special tokens with patch tokens - tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) - - pos = None - if self.rope is not None: - pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) - - if self.patch_start_idx > 0: - # do not use position embedding for special tokens (camera and register tokens) - # so set pos to 0 for the special tokens - pos = pos + 1 - pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) - pos = torch.cat([pos_special, pos], dim=1) - - # update P because we added special tokens - _, P, C = tokens.shape - - frame_idx = 0 - global_idx = 0 - output_list = [] - - for _ in range(self.aa_block_num): - for attn_type in self.aa_order: - if attn_type == "frame": - tokens, frame_idx, frame_intermediates = self._process_frame_attention( - tokens, B, S, P, C, frame_idx, pos=pos - ) - elif attn_type == "global": - tokens, global_idx, global_intermediates = self._process_global_attention( - tokens, B, S, P, C, global_idx, pos=pos - ) - else: - raise ValueError(f"Unknown attention type: {attn_type}") - - for i in range(len(frame_intermediates)): - # concat frame and global intermediates, [B x S x P x 2C] - concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) - output_list.append(concat_inter) - - del concat_inter - del frame_intermediates - del global_intermediates - return output_list, self.patch_start_idx - - def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): - """ - Process frame attention blocks. We keep tokens in shape (B*S, P, C). 
- """ - # If needed, reshape tokens or positions: - if tokens.shape != (B * S, P, C): - tokens = tokens.view(B, S, P, C).view(B * S, P, C) - - if pos is not None and pos.shape != (B * S, P, 2): - pos = pos.view(B, S, P, 2).view(B * S, P, 2) - - intermediates = [] - - # by default, self.aa_block_size=1, which processes one block at a time - for _ in range(self.aa_block_size): - if self.training: - tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant) - else: - tokens = self.frame_blocks[frame_idx](tokens, pos=pos) - frame_idx += 1 - intermediates.append(tokens.view(B, S, P, C)) - - return tokens, frame_idx, intermediates - - def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): - """ - Process global attention blocks. We keep tokens in shape (B, S*P, C). - """ - if tokens.shape != (B, S * P, C): - tokens = tokens.view(B, S, P, C).view(B, S * P, C) - - if pos is not None and pos.shape != (B, S * P, 2): - pos = pos.view(B, S, P, 2).view(B, S * P, 2) - - intermediates = [] - - # by default, self.aa_block_size=1, which processes one block at a time - for _ in range(self.aa_block_size): - if self.training: - tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant) - else: - tokens = self.global_blocks[global_idx](tokens, pos=pos) - global_idx += 1 - intermediates.append(tokens.view(B, S, P, C)) - - return tokens, global_idx, intermediates - - -def slice_expand_and_flatten(token_tensor, B, S): - """ - Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: - 1) Uses the first position (index=0) for the first frame only - 2) Uses the second position (index=1) for all remaining frames (S-1 frames) - 3) Expands both to match batch size B - 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token - followed by (S-1) second-position tokens - 5) Flattens to (B*S, X, C) for processing - - Returns: - torch.Tensor: Processed tokens with shape (B*S, X, C) - """ - - # Slice out the "query" tokens => shape (1, 1, ...) - query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) - # Slice out the "other" tokens => shape (1, S-1, ...) - others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) - # Concatenate => shape (B, S, ...) - combined = torch.cat([query, others], dim=1) - - # Finally flatten => shape (B*S, ...) - combined = combined.view(B * S, *combined.shape[2:]) - return combined diff --git a/capvector-pi05/src/vggt/models/vggt.py b/capvector-pi05/src/vggt/models/vggt.py deleted file mode 100644 index 7decb0a73def7ab3e124c0e873c7081bdc954542..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/models/vggt.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn -from huggingface_hub import PyTorchModelHubMixin # used for model hub - -from vggt.models.aggregator import Aggregator -from vggt.heads.camera_head import CameraHead -from vggt.heads.dpt_head import DPTHead -from vggt.heads.track_head import TrackHead - - -class VGGT(nn.Module, PyTorchModelHubMixin): - def __init__(self, img_size=518, patch_size=14, embed_dim=1024, - enable_camera=True, enable_point=True, enable_depth=True, enable_track=True, feature_only=False): - super().__init__() - - self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) - - self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None - self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") if enable_point else None - self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") if enable_depth else None - self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) if enable_track else None - self.feature_only = feature_only - self.embed_dim = embed_dim - self.patch_size = patch_size - - def forward(self, images: torch.Tensor, query_points: torch.Tensor = None): - """ - Forward pass of the VGGT model. - - Args: - images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. - B: batch size, S: sequence length, 3: RGB channels, H: height, W: width - query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. - Shape: [N, 2] or [B, N, 2], where N is the number of query points. - Default: None - - Returns: - dict: A dictionary containing the following predictions: - - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) - - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] - - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] - - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] - - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] - - images (torch.Tensor): Original input images, preserved for visualization - - If query_points is provided, also includes: - - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates - - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] - - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] - """ - # If without batch dimension, add it - if len(images.shape) == 4: - images = images.unsqueeze(0) - - if query_points is not None and len(query_points.shape) == 2: - query_points = query_points.unsqueeze(0) - - aggregated_tokens_list, patch_start_idx = self.aggregator(images) - - predictions = {} - - with torch.amp.autocast(device_type="cuda", enabled=False): - if self.camera_head is not None: - pose_enc_list = self.camera_head(aggregated_tokens_list) - predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration - predictions["pose_enc_list"] = pose_enc_list - - if self.depth_head is not None: - depth, depth_conf = self.depth_head( - aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx - ) - predictions["depth"] = depth - predictions["depth_conf"] = depth_conf - - if self.point_head is not None: - pts3d, pts3d_conf = self.point_head( - aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx - ) 
- predictions["world_points"] = pts3d - predictions["world_points_conf"] = pts3d_conf - - if self.track_head is not None and query_points is not None: - track_list, vis, conf = self.track_head( - aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points - ) - predictions["track"] = track_list[-1] # track of the last iteration - predictions["vis"] = vis - predictions["conf"] = conf - - if not self.training: - predictions["images"] = images # store the images for visualization during inference - - if self.feature_only: - predictions["features"] = aggregated_tokens_list - predictions['patch_start_idx'] = patch_start_idx - return predictions - diff --git a/capvector-pi05/src/vggt/pyproject.toml b/capvector-pi05/src/vggt/pyproject.toml deleted file mode 100644 index 2e3ea16f6bc84967a90d7d319c204ca92936b8b5..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/pyproject.toml +++ /dev/null @@ -1,52 +0,0 @@ -[project] -authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}] -dependencies = [ - "numpy<2", - "Pillow", - "huggingface_hub", - "einops", - "safetensors", - "opencv-python", -] -name = "vggt" -requires-python = ">= 3.10" -version = "0.0.1" - -[project.optional-dependencies] -demo = [ - "gradio==5.17.1", - "viser==0.2.23", - "tqdm", - "hydra-core", - "omegaconf", - "opencv-python", - "scipy", - "onnxruntime", - "requests", - "trimesh", - "matplotlib", -] - -# Using setuptools as the build backend -[build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" - -# setuptools configuration -[tool.setuptools.packages.find] -where = ["."] -include = ["vggt*"] - -# Pixi configuration -[tool.pixi.workspace] -channels = ["conda-forge"] -platforms = ["linux-64"] - -[tool.pixi.pypi-dependencies] -vggt = { path = ".", editable = true } - -[tool.pixi.environments] -default = { solve-group = "default" } -demo = { features = ["demo"], solve-group = "default" } - -[tool.pixi.tasks] diff --git a/capvector-pi05/src/vggt/utils/geometry.py b/capvector-pi05/src/vggt/utils/geometry.py deleted file mode 100644 index 81c6a2eb3623a7824e55dd675595e7abe9350fae..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/utils/geometry.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import os -import torch -import numpy as np - - -from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion - - -def unproject_depth_map_to_point_map( - depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray -) -> np.ndarray: - """ - Unproject a batch of depth maps to 3D world coordinates. 
- - Args: - depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) - extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) - intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) - - Returns: - np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) - """ - if isinstance(depth_map, torch.Tensor): - depth_map = depth_map.cpu().numpy() - if isinstance(extrinsics_cam, torch.Tensor): - extrinsics_cam = extrinsics_cam.cpu().numpy() - if isinstance(intrinsics_cam, torch.Tensor): - intrinsics_cam = intrinsics_cam.cpu().numpy() - - world_points_list = [] - for frame_idx in range(depth_map.shape[0]): - cur_world_points, _, _ = depth_to_world_coords_points( - depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] - ) - world_points_list.append(cur_world_points) - world_points_array = np.stack(world_points_list, axis=0) - - return world_points_array - - -def depth_to_world_coords_points( - depth_map: np.ndarray, - extrinsic: np.ndarray, - intrinsic: np.ndarray, - eps=1e-8, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Convert a depth map to world coordinates. - - Args: - depth_map (np.ndarray): Depth map of shape (H, W). - intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). - extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. - - Returns: - tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). - """ - if depth_map is None: - return None, None, None - - # Valid depth mask - point_mask = depth_map > eps - - # Convert depth map to camera coordinates - cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) - - # Multiply with the inverse of extrinsic matrix to transform to world coordinates - # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) - cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] - - R_cam_to_world = cam_to_world_extrinsic[:3, :3] - t_cam_to_world = cam_to_world_extrinsic[:3, 3] - - # Apply the rotation and translation to the camera coordinates - world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 - # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world - - return world_coords_points, cam_coords_points, point_mask - - -def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """ - Convert a depth map to camera coordinates. - - Args: - depth_map (np.ndarray): Depth map of shape (H, W). - intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 
- - Returns: - tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) - """ - H, W = depth_map.shape - assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" - assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" - - # Intrinsic parameters - fu, fv = intrinsic[0, 0], intrinsic[1, 1] - cu, cv = intrinsic[0, 2], intrinsic[1, 2] - - # Generate grid of pixel coordinates - u, v = np.meshgrid(np.arange(W), np.arange(H)) - - # Unproject to camera coordinates - x_cam = (u - cu) * depth_map / fu - y_cam = (v - cv) * depth_map / fv - z_cam = depth_map - - # Stack to form camera coordinates - cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) - - return cam_coords - - -def closed_form_inverse_se3(se3, R=None, T=None): - """ - Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. - - If `R` and `T` are provided, they must correspond to the rotation and translation - components of `se3`. Otherwise, they will be extracted from `se3`. - - Args: - se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. - R (optional): Nx3x3 array or tensor of rotation matrices. - T (optional): Nx3x1 array or tensor of translation vectors. - - Returns: - Inverted SE3 matrices with the same type and device as `se3`. - - Shapes: - se3: (N, 4, 4) - R: (N, 3, 3) - T: (N, 3, 1) - """ - # Check if se3 is a numpy array or a torch tensor - is_numpy = isinstance(se3, np.ndarray) - - # Validate shapes - if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): - raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") - - # Extract R and T if not provided - if R is None: - R = se3[:, :3, :3] # (N,3,3) - if T is None: - T = se3[:, :3, 3:] # (N,3,1) - - # Transpose R - if is_numpy: - # Compute the transpose of the rotation for NumPy - R_transposed = np.transpose(R, (0, 2, 1)) - # -R^T t for NumPy - top_right = -np.matmul(R_transposed, T) - inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) - else: - R_transposed = R.transpose(1, 2) # (N,3,3) - top_right = -torch.bmm(R_transposed, T) # (N,3,1) - inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) - inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) - - inverted_matrix[:, :3, :3] = R_transposed - inverted_matrix[:, :3, 3:] = top_right - - return inverted_matrix - - -# TODO: this code can be further cleaned up - - -def project_world_points_to_camera_points_batch(world_points, cam_extrinsics): - """ - Transforms 3D points to 2D using extrinsic and intrinsic parameters. - Args: - world_points (torch.Tensor): 3D points of shape BxSxHxWx3. - cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4. 
- Returns: - """ - # TODO: merge this into project_world_points_to_cam - - # device = world_points.device - # with torch.autocast(device_type=device.type, enabled=False): - ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1) - world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4) - - # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) - extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3) - - # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1) - world_points_h_exp = world_points_h.unsqueeze(-1) - - # Now perform the matrix multiplication - # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1) - camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1) - - return camera_points - - - -def project_world_points_to_cam( - world_points, - cam_extrinsics, - cam_intrinsics=None, - distortion_params=None, - default=0, - only_points_cam=False, -): - """ - Transforms 3D points to 2D using extrinsic and intrinsic parameters. - Args: - world_points (torch.Tensor): 3D points of shape Px3. - cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. - cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. - distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion. - Returns: - torch.Tensor: Transformed 2D points of shape BxNx2. - """ - device = world_points.device - # with torch.autocast(device_type=device.type, dtype=torch.double): - with torch.autocast(device_type=device.type, enabled=False): - N = world_points.shape[0] # Number of points - B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras - world_points_homogeneous = torch.cat( - [world_points, torch.ones_like(world_points[..., 0:1])], dim=1 - ) # Nx4 - # Reshape for batch processing - world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand( - B, -1, -1 - ) # BxNx4 - - # Step 1: Apply extrinsic parameters - # Transform 3D points to camera coordinate system for all cameras - cam_points = torch.bmm( - cam_extrinsics, world_points_homogeneous.transpose(-1, -2) - ) - - if only_points_cam: - return None, cam_points - - # Step 2: Apply intrinsic parameters and (optional) distortion - image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default) - - return image_points, cam_points - - - -def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0): - """ - Applies intrinsic parameters and optional distortion to the given 3D points. - - Args: - cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. - cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. - distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. - default (float, optional): Default value to replace NaNs in the output. - - Returns: - pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
- """ - - # Normalized device coordinates (NDC) - cam_points = cam_points / cam_points[:, 2:3, :] - ndc_xy = cam_points[:, :2, :] - - # Apply distortion if distortion_params are provided - if distortion_params is not None: - x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1]) - distorted_xy = torch.stack([x_distorted, y_distorted], dim=1) - else: - distorted_xy = ndc_xy - - # Prepare cam_points for batch matrix multiplication - cam_coords_homo = torch.cat( - (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1 - ) # Bx3xN - # Apply intrinsic parameters using batch matrix multiplication - pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN - - # Extract x and y coordinates - pixel_coords = pixel_coords[:, :2, :] # Bx2xN - - # Replace NaNs with default value - pixel_coords = torch.nan_to_num(pixel_coords, nan=default) - - return pixel_coords.transpose(1, 2) # BxNx2 - - - - -def cam_from_img(pred_tracks, intrinsics, extra_params=None): - """ - Normalize predicted tracks based on camera intrinsics. - Args: - intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. - pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. - extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. - Returns: - torch.Tensor: Normalized tracks tensor. - """ - - # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here - # otherwise we can use something like - # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2)) - - principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) - focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) - tracks_normalized = (pred_tracks - principal_point) / focal_length - - if extra_params is not None: - # Apply iterative undistortion - try: - tracks_normalized = iterative_undistortion( - extra_params, tracks_normalized - ) - except: - tracks_normalized = single_undistortion( - extra_params, tracks_normalized - ) - - return tracks_normalized \ No newline at end of file diff --git a/capvector-pi05/src/vggt/utils/helper.py b/capvector-pi05/src/vggt/utils/helper.py deleted file mode 100644 index 405edd1879055fa5eda820438753fb99b82297a5..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/utils/helper.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np - - -def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: - """ - If mask has more than max_trues True values, - randomly keep only max_trues of them and set the rest to False. 
- """ - # 1D positions of all True entries - true_indices = np.flatnonzero(mask) # shape = (N_true,) - - # if already within budget, return as-is - if true_indices.size <= max_trues: - return mask - - # randomly pick which True positions to keep - sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) - - # build new flat mask: True only at sampled positions - limited_flat_mask = np.zeros(mask.size, dtype=bool) - limited_flat_mask[sampled_indices] = True - - # restore original shape - return limited_flat_mask.reshape(mask.shape) - - -def create_pixel_coordinate_grid(num_frames, height, width): - """ - Creates a grid of pixel coordinates and frame indices for all frames. - Returns: - tuple: A tuple containing: - - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) - with x, y coordinates and frame indices - - y_coords (numpy.ndarray): Array of y coordinates for all frames - - x_coords (numpy.ndarray): Array of x coordinates for all frames - - f_coords (numpy.ndarray): Array of frame indices for all frames - """ - # Create coordinate grids for a single frame - y_grid, x_grid = np.indices((height, width), dtype=np.float32) - x_grid = x_grid[np.newaxis, :, :] - y_grid = y_grid[np.newaxis, :, :] - - # Broadcast to all frames - x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) - y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) - - # Create frame indices and broadcast - f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] - f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) - - # Stack coordinates and frame indices - points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) - - return points_xyf diff --git a/capvector-pi05/src/vggt/utils/load_fn.py b/capvector-pi05/src/vggt/utils/load_fn.py deleted file mode 100644 index 0cb0dcab98cb3ce92de74f7875e4bf33cfb936e3..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/utils/load_fn.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F -from PIL import Image -from torchvision import transforms as TF -import numpy as np -import torchvision.transforms.functional as TVF -from copy import deepcopy - - -def load_and_preprocess_images_square(image_path_list, target_size=1024): - """ - Load and preprocess images by center padding to square and resizing to target size. - Also returns the position information of original pixels after transformation. - - Args: - image_path_list (list): List of paths to image files - target_size (int, optional): Target size for both width and height. Defaults to 518. 
- - Returns: - tuple: ( - torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), - torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image - ) - - Raises: - ValueError: If the input list is empty - """ - # Check for empty list - if len(image_path_list) == 0: - raise ValueError("At least 1 image is required") - - images = [] - original_coords = [] # Renamed from position_info to be more descriptive - to_tensor = TF.ToTensor() - - for image_path in image_path_list: - # Open image - img = Image.open(image_path) - - # If there's an alpha channel, blend onto white background - if img.mode == "RGBA": - background = Image.new("RGBA", img.size, (255, 255, 255, 255)) - img = Image.alpha_composite(background, img) - - # Convert to RGB - img = img.convert("RGB") - - # Get original dimensions - width, height = img.size - - # Make the image square by padding the shorter dimension - max_dim = max(width, height) - - # Calculate padding - left = (max_dim - width) // 2 - top = (max_dim - height) // 2 - - # Calculate scale factor for resizing - scale = target_size / max_dim - - # Calculate final coordinates of original image in target space - x1 = left * scale - y1 = top * scale - x2 = (left + width) * scale - y2 = (top + height) * scale - - # Store original image coordinates and scale - original_coords.append(np.array([x1, y1, x2, y2, width, height])) - - # Create a new black square image and paste original - square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) - square_img.paste(img, (left, top)) - - # Resize to target size - square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC) - - # Convert to tensor - img_tensor = to_tensor(square_img) - images.append(img_tensor) - - # Stack all images - images = torch.stack(images) - original_coords = torch.from_numpy(np.array(original_coords)).float() - - # Add additional dimension if single image to ensure correct shape - if len(image_path_list) == 1: - if images.dim() == 3: - images = images.unsqueeze(0) - original_coords = original_coords.unsqueeze(0) - - return images, original_coords - - -def load_and_preprocess_images(image_path_list, mode="crop"): - """ - A quick start function to load and preprocess images for model input. - This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. - - Args: - image_path_list (list): List of paths to image files - mode (str, optional): Preprocessing mode, either "crop" or "pad". - - "crop" (default): Sets width to 518px and center crops height if needed. - - "pad": Preserves all pixels by making the largest dimension 518px - and padding the smaller dimension to reach a square shape. 
- - Returns: - torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) - - Raises: - ValueError: If the input list is empty or if mode is invalid - - Notes: - - Images with different dimensions will be padded with white (value=1.0) - - A warning is printed when images have different shapes - - When mode="crop": The function ensures width=518px while maintaining aspect ratio - and height is center-cropped if larger than 518px - - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio - and the smaller dimension is padded to reach a square shape (518x518) - - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements - """ - # Check for empty list - if len(image_path_list) == 0: - raise ValueError("At least 1 image is required") - - # Validate mode - if mode not in ["crop", "pad"]: - raise ValueError("Mode must be either 'crop' or 'pad'") - - images = [] - shapes = set() - to_tensor = TF.ToTensor() - target_size = 518 - - # First process all images and collect their shapes - for image_path in image_path_list: - # Open image - img = Image.open(image_path) - - # If there's an alpha channel, blend onto white background: - if img.mode == "RGBA": - # Create white background - background = Image.new("RGBA", img.size, (255, 255, 255, 255)) - # Alpha composite onto the white background - img = Image.alpha_composite(background, img) - - # Now convert to "RGB" (this step assigns white for transparent areas) - img = img.convert("RGB") - - width, height = img.size - - if mode == "pad": - # Make the largest dimension 518px while maintaining aspect ratio - if width >= height: - new_width = target_size - new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 - else: - new_height = target_size - new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 - else: # mode == "crop" - # Original behavior: set width to 518px - new_width = target_size - # Calculate height maintaining aspect ratio, divisible by 14 - new_height = round(height * (new_width / width) / 14) * 14 - - # Resize with new dimensions (width, height) - img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) - img = to_tensor(img) # Convert to tensor (0, 1) - - # Center crop height if it's larger than 518 (only in crop mode) - if mode == "crop" and new_height > target_size: - start_y = (new_height - target_size) // 2 - img = img[:, start_y : start_y + target_size, :] - - # For pad mode, pad to make a square of target_size x target_size - if mode == "pad": - h_padding = target_size - img.shape[1] - w_padding = target_size - img.shape[2] - - if h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - # Pad with white (value=1.0) - img = torch.nn.functional.pad( - img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 - ) - - shapes.add((img.shape[1], img.shape[2])) - images.append(img) # 3*518*518 - - # Check if we have different shapes - # In theory our model can also work well with different shapes - if len(shapes) > 1: - print(f"Warning: Found images with different shapes: {shapes}") - # Find maximum dimensions - max_height = max(shape[0] for shape in shapes) - max_width = max(shape[1] for shape in shapes) - - # Pad images if necessary - padded_images = [] - for img in images: - h_padding = max_height - img.shape[1] - w_padding = max_width - img.shape[2] - - if 
h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - img = torch.nn.functional.pad( - img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 - ) - padded_images.append(img) - images = padded_images - # N*3*518*518 - images = torch.stack(images) # concatenate images - - # Ensure correct shape when single image - if len(image_path_list) == 1: - # Verify shape is (1, C, H, W) - if images.dim() == 3: - images = images.unsqueeze(0) - - return images - - -def preprocess_images_from_openpi(image_list, mode="crop"): - """ - This function convert a list of images provided by openpi (e.g., pi0) to the demanded format for VGGT. - Mostly, the openpi images are already square, and there is no need to crop or pad them. - Then images are resized to 518x518. - - Args: - image_list (list(torch.tensor)): each image in the list is a torch tensor with shape (B, C, H, W) - Returns: - torch.Tensor: Batched tensor of preprocessed images with shape (B, N, C, H, W) - """ - # Check for empty list - if len(image_list) == 0: - raise ValueError("At least 1 image is required") - - # Validate mode - if mode not in ["crop", "pad"]: - raise ValueError("Mode must be either 'crop' or 'pad'") - - images = [] - shapes = set() - target_size = 518 - - # Resize primary and wrist images to 518px (VGGT required) - _height, _width = image_list[0].shape[-2:] - if mode == "pad": - # Make the largest dimension 518px while maintaining aspect ratio - if _width >= _height: - new_width = target_size - new_height = round(_height * (new_width / _width) / 14) * 14 # Make divisible by 14 - else: - new_height = target_size - new_width = round(_width * (new_height / _height) / 14) * 14 # Make divisible by 14 - else: # mode == "crop" - # Original behavior: set width to 518px - new_width = target_size - # Calculate height maintaining aspect ratio, divisible by 14 - new_height = round(_height * (new_width / _width) / 14) * 14 - interpolate_img_list = [ - F.interpolate(img, size=(new_height, new_width), mode='bicubic', align_corners=False) for img in image_list - ] - - # Center crop height if it's larger than 518 (only in crop mode) - if mode == "crop": - start_y = (new_height - target_size) // 2 - reshaped_img_list = [img[:, :, start_y : start_y + target_size, :] for img in interpolate_img_list] - - # For pad mode, pad to make a square of target_size x target_size - if mode == "pad": - height, width = interpolate_img_list[0].shape[-2:] - h_padding = target_size - height - w_padding = target_size - width - - if h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - # Pad with white (value=1.0) - reshaped_img_list = [ - torch.nn.functional.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0) - for img in interpolate_img_list - ] - else: - reshaped_img_list = interpolate_img_list - - return torch.stack(reshaped_img_list, dim=1) # [bs, N, C, H, W] \ No newline at end of file diff --git a/capvector-pi05/src/vggt/utils/pose_enc.py b/capvector-pi05/src/vggt/utils/pose_enc.py deleted file mode 100644 index a6ddccbe71f3587860e91f9e3eb2a9a3c3af1ab6..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/utils/pose_enc.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from .rotation import quat_to_mat, mat_to_quat - - -def extri_intri_to_pose_encoding( - extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512) -): - """Convert camera extrinsics and intrinsics to a compact pose encoding. - - This function transforms camera parameters into a unified pose encoding format, - which can be used for various downstream tasks like pose prediction or representation. - - Args: - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, - where B is batch size and S is sequence length. - In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. - The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. - intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. - Defined in pixels, with format: - [[fx, 0, cx], - [0, fy, cy], - [0, 0, 1]] - where fx, fy are focal lengths and (cx, cy) is the principal point - image_size_hw (tuple): Tuple of (height, width) of the image in pixels. - Required for computing field of view values. For example: (256, 512). - pose_encoding_type (str): Type of pose encoding to use. Currently only - supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). - - Returns: - torch.Tensor: Encoded camera pose parameters with shape BxSx9. - For "absT_quaR_FoV" type, the 9 dimensions are: - - [:3] = absolute translation vector T (3D) - - [3:7] = rotation as quaternion quat (4D) - - [7:] = field of view (2D) - """ - - # extrinsics: BxSx3x4 - # intrinsics: BxSx3x3 - - if pose_encoding_type == "absT_quaR_FoV": - R = extrinsics[:, :, :3, :3] # BxSx3x3 - T = extrinsics[:, :, :3, 3] # BxSx3 - - quat = mat_to_quat(R) - # Note the order of h and w here - H, W = image_size_hw - fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) - fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) - pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() - else: - raise NotImplementedError - - return pose_encoding - - -def pose_encoding_to_extri_intri( - pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) -): - """Convert a pose encoding back to camera extrinsics and intrinsics. - - This function performs the inverse operation of extri_intri_to_pose_encoding, - reconstructing the full camera parameters from the compact encoding. - - Args: - pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, - where B is batch size and S is sequence length. - For "absT_quaR_FoV" type, the 9 dimensions are: - - [:3] = absolute translation vector T (3D) - - [3:7] = rotation as quaternion quat (4D) - - [7:] = field of view (2D) - image_size_hw (tuple): Tuple of (height, width) of the image in pixels. - Required for reconstructing intrinsics from field of view values. - For example: (256, 512). - pose_encoding_type (str): Type of pose encoding used. Currently only - supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). - build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. - If False, only extrinsics are returned and intrinsics will be None. - - Returns: - tuple: (extrinsics, intrinsics) - - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. 
- In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world - transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is - a 3x1 translation vector. - - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, - or None if build_intrinsics is False. Defined in pixels, with format: - [[fx, 0, cx], - [0, fy, cy], - [0, 0, 1]] - where fx, fy are focal lengths and (cx, cy) is the principal point, - assumed to be at the center of the image (W/2, H/2). - """ - - intrinsics = None - - if pose_encoding_type == "absT_quaR_FoV": - T = pose_encoding[..., :3] - quat = pose_encoding[..., 3:7] - fov_h = pose_encoding[..., 7] - fov_w = pose_encoding[..., 8] - - R = quat_to_mat(quat) - extrinsics = torch.cat([R, T[..., None]], dim=-1) - - if build_intrinsics: - H, W = image_size_hw - fy = (H / 2.0) / torch.tan(fov_h / 2.0) - fx = (W / 2.0) / torch.tan(fov_w / 2.0) - intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) - intrinsics[..., 0, 0] = fx - intrinsics[..., 1, 1] = fy - intrinsics[..., 0, 2] = W / 2 - intrinsics[..., 1, 2] = H / 2 - intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 - else: - raise NotImplementedError - - return extrinsics, intrinsics diff --git a/capvector-pi05/src/vggt/utils/rotation.py b/capvector-pi05/src/vggt/utils/rotation.py deleted file mode 100644 index 494f176450cbf1fd4dd3ef21787201b12f357843..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/utils/rotation.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d - -import torch -import numpy as np -import torch.nn.functional as F - - -def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: - """ - Quaternion Order: XYZW or say ijkr, scalar-last - - Convert rotations given as quaternions to rotation matrices. - Args: - quaternions: quaternions with real part last, - as tensor of shape (..., 4). - - Returns: - Rotation matrices as tensor of shape (..., 3, 3). - """ - i, j, k, r = torch.unbind(quaternions, -1) - # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. - two_s = 2.0 / (quaternions * quaternions).sum(-1) - - o = torch.stack( - ( - 1 - two_s * (j * j + k * k), - two_s * (i * j - k * r), - two_s * (i * k + j * r), - two_s * (i * j + k * r), - 1 - two_s * (i * i + k * k), - two_s * (j * k - i * r), - two_s * (i * k - j * r), - two_s * (j * k + i * r), - 1 - two_s * (i * i + j * j), - ), - -1, - ) - return o.reshape(quaternions.shape[:-1] + (3, 3)) - - -def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: - """ - Convert rotations given as rotation matrices to quaternions. - - Args: - matrix: Rotation matrices as tensor of shape (..., 3, 3). - - Returns: - quaternions with real part last, as tensor of shape (..., 4). 
- Quaternion Order: XYZW or say ijkr, scalar-last - """ - if matrix.size(-1) != 3 or matrix.size(-2) != 3: - raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") - - batch_dim = matrix.shape[:-2] - m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) - - q_abs = _sqrt_positive_part( - torch.stack( - [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 - ) - ) - - # we produce the desired quaternion multiplied by each of r, i, j, k - quat_by_rijk = torch.stack( - [ - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), - ], - dim=-2, - ) - - # We floor here at 0.1 but the exact level is not important; if q_abs is small, - # the candidate won't be picked. - flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) - quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) - - # if not for numerical problems, quat_candidates[i] should be same (up to a sign), - # forall i; we pick the best-conditioned one (with the largest denominator) - out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) - - # Convert from rijk to ijkr - out = out[..., [1, 2, 3, 0]] - - out = standardize_quaternion(out) - - return out - - -def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: - """ - Returns torch.sqrt(torch.max(0, x)) - but with a zero subgradient where x is 0. - """ - ret = torch.zeros_like(x) - positive_mask = x > 0 - if torch.is_grad_enabled(): - ret[positive_mask] = torch.sqrt(x[positive_mask]) - else: - ret = torch.where(positive_mask, torch.sqrt(x), ret) - return ret - - -def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: - """ - Convert a unit quaternion to a standard form: one in which the real - part is non negative. - - Args: - quaternions: Quaternions with real part last, - as tensor of shape (..., 4). - - Returns: - Standardized quaternions as tensor of shape (..., 4). - """ - return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/capvector-pi05/src/vggt/utils/visual_track.py b/capvector-pi05/src/vggt/utils/visual_track.py deleted file mode 100644 index a4e0a27dedb261173ccd5f87d04fbe472657bb96..0000000000000000000000000000000000000000 --- a/capvector-pi05/src/vggt/utils/visual_track.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import cv2 -import torch -import numpy as np -import os - - -def color_from_xy(x, y, W, H, cmap_name="hsv"): - """ - Map (x, y) -> color in (R, G, B). - 1) Normalize x,y to [0,1]. - 2) Combine them into a single scalar c in [0,1]. - 3) Use matplotlib's colormap to convert c -> (R,G,B). - - You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). 
- """ - import matplotlib.cm - import matplotlib.colors - - x_norm = x / max(W - 1, 1) - y_norm = y / max(H - 1, 1) - # Simple combination: - c = (x_norm + y_norm) / 2.0 - - cmap = matplotlib.cm.get_cmap(cmap_name) - # cmap(c) -> (r,g,b,a) in [0,1] - rgba = cmap(c) - r, g, b = rgba[0], rgba[1], rgba[2] - return (r, g, b) # in [0,1], RGB order - - -def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): - """ - Given all tracks in one sample (b), compute a (N,3) array of RGB color values - in [0,255]. The color is determined by the (x,y) position in the first - visible frame for each track. - - Args: - tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. - vis_mask_b: (S, N) boolean mask; if None, assume all are visible. - image_width, image_height: used for normalizing (x, y). - cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). - - Returns: - track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. - """ - S, N, _ = tracks_b.shape - track_colors = np.zeros((N, 3), dtype=np.uint8) - - if vis_mask_b is None: - # treat all as visible - vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) - - for i in range(N): - # Find first visible frame for track i - visible_frames = torch.where(vis_mask_b[:, i])[0] - if len(visible_frames) == 0: - # track is never visible; just assign black or something - track_colors[i] = (0, 0, 0) - continue - - first_s = int(visible_frames[0].item()) - # use that frame's (x,y) - x, y = tracks_b[first_s, i].tolist() - - # map (x,y) -> (R,G,B) in [0,1] - r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) - # scale to [0,255] - r, g, b = int(r * 255), int(g * 255), int(b * 255) - track_colors[i] = (r, g, b) - - return track_colors - - -def visualize_tracks_on_images( - images, - tracks, - track_vis_mask=None, - out_dir="track_visuals_concat_by_xy", - image_format="CHW", # "CHW" or "HWC" - normalize_mode="[0,1]", - cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" - frames_per_row=4, # New parameter for grid layout - save_grid=True, # Flag to control whether to save the grid image -): - """ - Visualizes frames in a grid layout with specified frames per row. - Each track's color is determined by its (x,y) position - in the first visible frame (or frame 0 if always visible). - Finally convert the BGR result to RGB before saving. - Also saves each individual frame as a separate PNG file. - - Args: - images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. - tracks: torch.Tensor (S, N, 2), last dim = (x, y). - track_vis_mask: torch.Tensor (S, N) or None. - out_dir: folder to save visualizations. - image_format: "CHW" or "HWC". - normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 - cmap_name: a matplotlib colormap name for color_from_xy. - frames_per_row: number of frames to display in each row of the grid. - save_grid: whether to save all frames in one grid image. - - Returns: - None (saves images in out_dir). 
- """ - - if len(tracks.shape) == 4: - tracks = tracks.squeeze(0) - images = images.squeeze(0) - if track_vis_mask is not None: - track_vis_mask = track_vis_mask.squeeze(0) - - import matplotlib - - matplotlib.use("Agg") # for non-interactive (optional) - - os.makedirs(out_dir, exist_ok=True) - - S = images.shape[0] - _, N, _ = tracks.shape # (S, N, 2) - - # Move to CPU - images = images.cpu().clone() - tracks = tracks.cpu().clone() - if track_vis_mask is not None: - track_vis_mask = track_vis_mask.cpu().clone() - - # Infer H, W from images shape - if image_format == "CHW": - # e.g. images[s].shape = (3, H, W) - H, W = images.shape[2], images.shape[3] - else: - # e.g. images[s].shape = (H, W, 3) - H, W = images.shape[1], images.shape[2] - - # Pre-compute the color for each track i based on first visible position - track_colors_rgb = get_track_colors_by_position( - tracks, # shape (S, N, 2) - vis_mask_b=track_vis_mask if track_vis_mask is not None else None, - image_width=W, - image_height=H, - cmap_name=cmap_name, - ) - - # We'll accumulate each frame's drawn image in a list - frame_images = [] - - for s in range(S): - # shape => either (3, H, W) or (H, W, 3) - img = images[s] - - # Convert to (H, W, 3) - if image_format == "CHW": - img = img.permute(1, 2, 0) # (H, W, 3) - # else "HWC", do nothing - - img = img.numpy().astype(np.float32) - - # Scale to [0,255] if needed - if normalize_mode == "[0,1]": - img = np.clip(img, 0, 1) * 255.0 - elif normalize_mode == "[-1,1]": - img = (img + 1.0) * 0.5 * 255.0 - img = np.clip(img, 0, 255.0) - # else no normalization - - # Convert to uint8 - img = img.astype(np.uint8) - - # For drawing in OpenCV, convert to BGR - img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - - # Draw each visible track - cur_tracks = tracks[s] # shape (N, 2) - if track_vis_mask is not None: - valid_indices = torch.where(track_vis_mask[s])[0] - else: - valid_indices = range(N) - - cur_tracks_np = cur_tracks.numpy() - for i in valid_indices: - x, y = cur_tracks_np[i] - pt = (int(round(x)), int(round(y))) - - # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR - R, G, B = track_colors_rgb[i] - color_bgr = (int(B), int(G), int(R)) - cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) - - # Convert back to RGB for consistent final saving: - img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) - - # Save individual frame - frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") - # Convert to BGR for OpenCV imwrite - frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) - cv2.imwrite(frame_path, frame_bgr) - - frame_images.append(img_rgb) - - # Only create and save the grid image if save_grid is True - if save_grid: - # Calculate grid dimensions - num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division - - # Create a grid of images - grid_img = None - for row in range(num_rows): - start_idx = row * frames_per_row - end_idx = min(start_idx + frames_per_row, S) - - # Concatenate this row horizontally - row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) - - # If this row has fewer than frames_per_row images, pad with black - if end_idx - start_idx < frames_per_row: - padding_width = (frames_per_row - (end_idx - start_idx)) * W - padding = np.zeros((H, padding_width, 3), dtype=np.uint8) - row_img = np.concatenate([row_img, padding], axis=1) - - # Add this row to the grid - if grid_img is None: - grid_img = row_img - else: - grid_img = np.concatenate([grid_img, row_img], axis=0) - - out_path = os.path.join(out_dir, "tracks_grid.png") - # Convert back to BGR for OpenCV imwrite - grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) - cv2.imwrite(out_path, grid_img_bgr) - print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") - - print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
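
To make the relationship between the deleted VGGT utilities concrete, here is a minimal usage sketch; it is not taken from the repository. The checkpoint id and image paths are illustrative assumptions, and `from_pretrained` comes from the `PyTorchModelHubMixin` base class used by `VGGT` above.

```python
# Minimal end-to-end sketch (assumptions: the vggt package is importable, a
# checkpoint is available under the illustrative id, and two frames exist on disk).
import torch

from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

device = "cuda" if torch.cuda.is_available() else "cpu"

# PyTorchModelHubMixin provides from_pretrained(); the checkpoint id is illustrative.
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device).eval()

# (S, 3, H, W) tensor in [0, 1]; forward() adds the batch dimension itself.
images = load_and_preprocess_images(["frame_0.png", "frame_1.png"]).to(device)

with torch.no_grad():
    preds = model(images)  # dict with pose_enc (B, S, 9), depth (B, S, H, W, 1), ...

# Recover extrinsics/intrinsics from the 9-D absT_quaR_FoV pose encoding.
extrinsics, intrinsics = pose_encoding_to_extri_intri(
    preds["pose_enc"], image_size_hw=images.shape[-2:]
)

# Unproject the predicted per-frame depth into a world-space point map.
point_map = unproject_depth_map_to_point_map(
    preds["depth"].squeeze(0),      # (S, H, W, 1)
    extrinsics.squeeze(0),          # (S, 3, 4)
    intrinsics.squeeze(0),          # (S, 3, 3)
)  # numpy array of shape (S, H, W, 3)
```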