Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- capvector-pi05/examples/libero/Dockerfile +59 -0
- capvector-pi05/examples/libero/README.md +71 -0
- capvector-pi05/examples/libero/main.py +219 -0
- capvector-pi05/examples/libero/requirements.in +11 -0
- capvector-pi05/examples/libero/requirements.txt +136 -0
- capvector-pi05/examples/simple_client/Dockerfile +32 -0
- capvector-pi05/examples/simple_client/README.md +30 -0
- capvector-pi05/examples/simple_client/compose.yml +42 -0
- capvector-pi05/examples/simple_client/main.py +187 -0
- capvector-pi05/examples/simple_client/requirements.in +5 -0
- capvector-pi05/examples/simple_client/requirements.txt +30 -0
- capvector-pi05/examples/ur5/README.md +142 -0
- capvector-pi05/packages/openpi-client/pyproject.toml +23 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py +1 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py +50 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py +12 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py +78 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py +37 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py +57 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py +45 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py +17 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py +18 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py +32 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py +92 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py +20 -0
- capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py +55 -0
- capvector-pi05/scripts/__init__.py +0 -0
- capvector-pi05/scripts/compute_norm_stats.py +117 -0
- capvector-pi05/scripts/docker/compose.yml +29 -0
- capvector-pi05/scripts/docker/install_docker_ubuntu22.sh +37 -0
- capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh +17 -0
- capvector-pi05/scripts/docker/serve_policy.Dockerfile +38 -0
- capvector-pi05/scripts/serve_policy.py +122 -0
- capvector-pi05/scripts/train.py +280 -0
- capvector-pi05/scripts/train_align_pytorch.py +658 -0
- capvector-pi05/scripts/train_pytorch.py +632 -0
- capvector-pi05/scripts/train_regular_loss_pytorch.py +754 -0
- capvector-pi05/scripts/train_test.py +30 -0
- capvector-pi05/src/openpi/__init__.py +0 -0
- capvector-pi05/src/openpi/conftest.py +17 -0
- capvector-pi05/src/openpi/models/__init__.py +0 -0
- capvector-pi05/src/openpi/models/gemma.py +459 -0
- capvector-pi05/src/openpi/models/gemma_fast.py +437 -0
- capvector-pi05/src/openpi/models/lora.py +148 -0
- capvector-pi05/src/openpi/models/lora_test.py +94 -0
- capvector-pi05/src/openpi/models/model.py +335 -0
- capvector-pi05/src/openpi/models/model_test.py +94 -0
- capvector-pi05/src/openpi/models/pi0.py +279 -0
- capvector-pi05/src/openpi/models/pi0_config.py +108 -0
- capvector-pi05/src/openpi/models/pi0_fast.py +313 -0
capvector-pi05/examples/libero/Dockerfile
ADDED
@@ -0,0 +1,59 @@
+# Dockerfile for the LIBERO benchmark.
+
+# Build the container:
+# docker build . -t libero -f examples/libero/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
+
+FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+RUN apt-get update && \
+    apt-get install -y \
+    make \
+    g++ \
+    clang \
+    libosmesa6-dev \
+    libgl1-mesa-glx \
+    libglew-dev \
+    libglfw3-dev \
+    libgles2-mesa-dev \
+    libglib2.0-0 \
+    libsm6 \
+    libxrender1 \
+    libxext6
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/libero/requirements.txt /tmp/requirements.txt
+COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
+
+# Create a default config file to avoid an input prompt from LIBERO's init script.
+# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
+ENV LIBERO_CONFIG_PATH=/tmp/libero
+RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
+benchmark_root: /app/third_party/libero/libero/libero
+bddl_files: /app/third_party/libero/libero/libero/bddl_files
+init_states: /app/third_party/libero/libero/libero/init_files
+datasets: /app/third_party/libero/libero/datasets
+assets: /app/third_party/libero/libero/libero/assets
+EOF
+
+CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
capvector-pi05/examples/libero/README.md
ADDED
@@ -0,0 +1,71 @@
+# LIBERO Benchmark
+
+This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
+
+Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
+
+This example requires git submodules to be initialized. Don't forget to run:
+
+```bash
+git submodule update --init --recursive
+```
+
+## With Docker (recommended)
+
+```bash
+# Grant access to the X11 server:
+sudo xhost +local:docker
+
+# To run with the default checkpoint and task suite:
+SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+
+# To run with glx for Mujoco instead (use this if you have egl errors):
+MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+```
+
+You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
+For example:
+
+```bash
+# To load a custom checkpoint (located in the top-level openpi/ directory):
+export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
+
+# To run the libero_10 task suite:
+export CLIENT_ARGS="--args.task-suite-name libero_10"
+```
+
+## Without Docker (not recommended)
+
+Terminal window 1:
+
+```bash
+# Create virtual environment
+uv venv --python 3.8 examples/libero/.venv
+source examples/libero/.venv/bin/activate
+uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+uv pip install -e packages/openpi-client
+uv pip install -e third_party/libero
+export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
+
+# Run the simulation
+python examples/libero/main.py
+
+# To run with glx for Mujoco instead (use this if you have egl errors):
+MUJOCO_GL=glx python examples/libero/main.py
+```
+
+Terminal window 2:
+
+```bash
+# Run the server
+uv run scripts/serve_policy.py --env LIBERO
+```
+
+## Results
+
+If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
+checkpoint was trained in openpi with the `pi05_libero` config.
+
+| Model                  | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
+|------------------------|----------------|---------------|-------------|-----------|---------|
+| π0.5 @ 30k (finetuned) | 98.8           | 98.2          | 98.0        | 92.4      | 96.85   |
capvector-pi05/examples/libero/main.py
ADDED
@@ -0,0 +1,219 @@
+import collections
+import dataclasses
+import logging
+import math
+import pathlib
+
+import imageio
+from libero.libero import benchmark
+from libero.libero import get_libero_path
+from libero.libero.envs import OffScreenRenderEnv
+import numpy as np
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy as _websocket_client_policy
+import tqdm
+import tyro
+
+LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
+LIBERO_ENV_RESOLUTION = 256  # resolution used to render training data
+
+
+@dataclasses.dataclass
+class Args:
+    #################################################################################################################
+    # Model server parameters
+    #################################################################################################################
+    host: str = "0.0.0.0"
+    port: int = 8000
+    resize_size: int = 224
+    replan_steps: int = 5
+
+    #################################################################################################################
+    # LIBERO environment-specific parameters
+    #################################################################################################################
+    task_suite_name: str = (
+        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
+    )
+    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
+    num_trials_per_task: int = 50  # Number of rollouts per task
+
+    #################################################################################################################
+    # Utils
+    #################################################################################################################
+    video_out_path: str = "data/libero/videos"  # Path to save videos
+
+    seed: int = 7  # Random Seed (for reproducibility)
+
+
+def eval_libero(args: Args) -> None:
+    # Set random seed
+    np.random.seed(args.seed)
+
+    # Initialize LIBERO task suite
+    benchmark_dict = benchmark.get_benchmark_dict()
+    task_suite = benchmark_dict[args.task_suite_name]()
+    num_tasks_in_suite = task_suite.n_tasks
+    logging.info(f"Task suite: {args.task_suite_name}")
+
+    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
+
+    if args.task_suite_name == "libero_spatial":
+        max_steps = 220  # longest training demo has 193 steps
+    elif args.task_suite_name == "libero_object":
+        max_steps = 280  # longest training demo has 254 steps
+    elif args.task_suite_name == "libero_goal":
+        max_steps = 300  # longest training demo has 270 steps
+    elif args.task_suite_name == "libero_10":
+        max_steps = 520  # longest training demo has 505 steps
+    elif args.task_suite_name == "libero_90":
+        max_steps = 400  # longest training demo has 373 steps
+    else:
+        raise ValueError(f"Unknown task suite: {args.task_suite_name}")
+
+    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
+
+    # Start evaluation
+    total_episodes, total_successes = 0, 0
+    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
+        # Get task
+        task = task_suite.get_task(task_id)
+
+        # Get default LIBERO initial states
+        initial_states = task_suite.get_task_init_states(task_id)
+
+        # Initialize LIBERO environment and task description
+        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
+
+        # Start episodes
+        task_episodes, task_successes = 0, 0
+        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
+            logging.info(f"\nTask: {task_description}")
+
+            # Reset environment
+            env.reset()
+            action_plan = collections.deque()
+
+            # Set initial states
+            obs = env.set_init_state(initial_states[episode_idx])
+
+            # Setup
+            t = 0
+            replay_images = []
+
+            logging.info(f"Starting episode {task_episodes+1}...")
+            while t < max_steps + args.num_steps_wait:
+                try:
+                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
+                    # and we need to wait for them to fall
+                    if t < args.num_steps_wait:
+                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
+                        t += 1
+                        continue
+
+                    # Get preprocessed image
+                    # IMPORTANT: rotate 180 degrees to match train preprocessing
+                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
+                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
+                    img = image_tools.convert_to_uint8(
+                        image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
+                    )
+                    wrist_img = image_tools.convert_to_uint8(
+                        image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
+                    )
+
+                    # Save preprocessed image for replay video
+                    replay_images.append(img)
+
+                    if not action_plan:
+                        # Finished executing previous action chunk -- compute new chunk
+                        # Prepare observations dict
+                        element = {
+                            "observation/image": img,
+                            "observation/wrist_image": wrist_img,
+                            "observation/state": np.concatenate(
+                                (
+                                    obs["robot0_eef_pos"],
+                                    _quat2axisangle(obs["robot0_eef_quat"]),
+                                    obs["robot0_gripper_qpos"],
+                                )
+                            ),
+                            "prompt": str(task_description),
+                        }
+
+                        # Query model to get action
+                        action_chunk = client.infer(element)["actions"]
+                        assert (
+                            len(action_chunk) >= args.replan_steps
+                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
+                        action_plan.extend(action_chunk[: args.replan_steps])
+
+                    action = action_plan.popleft()
+
+                    # Execute action in environment
+                    obs, reward, done, info = env.step(action.tolist())
+                    if done:
+                        task_successes += 1
+                        total_successes += 1
+                        break
+                    t += 1
+
+                except Exception as e:
+                    logging.error(f"Caught exception: {e}")
+                    break
+
+            task_episodes += 1
+            total_episodes += 1
+
+            # Save a replay video of the episode
+            suffix = "success" if done else "failure"
+            task_segment = task_description.replace(" ", "_")
+            imageio.mimwrite(
+                pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
+                [np.asarray(x) for x in replay_images],
+                fps=10,
+            )
+
+            # Log current results
+            logging.info(f"Success: {done}")
+            logging.info(f"# episodes completed so far: {total_episodes}")
+            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
+
+        # Log final results
+        logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
+        logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
+
+    logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
+    logging.info(f"Total episodes: {total_episodes}")
+
+
+def _get_libero_env(task, resolution, seed):
+    """Initializes and returns the LIBERO environment, along with the task description."""
+    task_description = task.language
+    task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
+    env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
+    env = OffScreenRenderEnv(**env_args)
+    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
+    return env, task_description
+
+
+def _quat2axisangle(quat):
+    """
+    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
+    """
+    # clip quaternion
+    if quat[3] > 1.0:
+        quat[3] = 1.0
+    elif quat[3] < -1.0:
+        quat[3] = -1.0
+
+    den = np.sqrt(1.0 - quat[3] * quat[3])
+    if math.isclose(den, 0.0):
+        # This is (close to) a zero degree rotation, immediately return
+        return np.zeros(3)
+
+    return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    tyro.cli(eval_libero)
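The `_quat2axisangle` helper above maps robosuite's `(x, y, z, w)` quaternion to an axis-angle vector whose direction is the rotation axis and whose norm is the rotation angle. A quick sanity check of that convention, inlining the same formula (a hedged sketch; it does not import the benchmark script):

```python
import math

import numpy as np

# A 90-degree rotation about the z-axis in (x, y, z, w) order:
# q = (0, 0, sin(45 deg), cos(45 deg)).
quat = np.array([0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)])

# Same formula as _quat2axisangle: xyz * 2 * acos(w) / sqrt(1 - w^2).
axis_angle = (quat[:3] * 2.0 * math.acos(quat[3])) / math.sqrt(1.0 - quat[3] ** 2)
print(axis_angle)  # ~[0, 0, 1.5708]: the z-axis scaled by the 90-degree angle (pi/2).
# The identity quaternion (w = 1) hits the near-zero denominator branch and maps to zeros.
```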
capvector-pi05/examples/libero/requirements.in
ADDED
@@ -0,0 +1,11 @@
+imageio[ffmpeg]
+numpy==1.22.4
+tqdm
+tyro
+PyYaml
+opencv-python==4.6.0.66
+torch==1.11.0+cu113
+torchvision==0.12.0+cu113
+torchaudio==0.11.0+cu113
+robosuite==1.4.1
+matplotlib==3.5.3
capvector-pi05/examples/libero/requirements.txt
ADDED
@@ -0,0 +1,136 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
+absl-py==2.1.0
+    # via mujoco
+certifi==2024.12.14
+    # via requests
+charset-normalizer==3.4.0
+    # via requests
+cycler==0.12.1
+    # via matplotlib
+docstring-parser==0.16
+    # via tyro
+etils==1.3.0
+    # via mujoco
+eval-type-backport==0.2.0
+    # via tyro
+evdev==1.7.1
+    # via pynput
+fonttools==4.55.3
+    # via matplotlib
+glfw==1.12.0
+    # via mujoco
+idna==3.10
+    # via requests
+imageio==2.35.1
+    # via -r examples/libero/requirements.in
+imageio-ffmpeg==0.5.1
+    # via imageio
+importlib-metadata==8.5.0
+    # via typeguard
+importlib-resources==6.4.5
+    # via etils
+kiwisolver==1.4.7
+    # via matplotlib
+llvmlite==0.36.0
+    # via numba
+markdown-it-py==3.0.0
+    # via rich
+matplotlib==3.5.3
+    # via -r examples/libero/requirements.in
+mdurl==0.1.2
+    # via markdown-it-py
+mujoco==3.2.3
+    # via robosuite
+numba==0.53.1
+    # via robosuite
+numpy==1.22.4
+    # via
+    #   -r examples/libero/requirements.in
+    #   imageio
+    #   matplotlib
+    #   mujoco
+    #   numba
+    #   opencv-python
+    #   robosuite
+    #   scipy
+    #   torchvision
+opencv-python==4.6.0.66
+    # via
+    #   -r examples/libero/requirements.in
+    #   robosuite
+packaging==24.2
+    # via matplotlib
+pillow==10.4.0
+    # via
+    #   imageio
+    #   matplotlib
+    #   robosuite
+    #   torchvision
+psutil==6.1.0
+    # via imageio
+pygments==2.18.0
+    # via rich
+pynput==1.7.7
+    # via robosuite
+pyopengl==3.1.7
+    # via mujoco
+pyparsing==3.1.4
+    # via matplotlib
+python-dateutil==2.9.0.post0
+    # via matplotlib
+python-xlib==0.33
+    # via pynput
+pyyaml==6.0.2
+    # via -r examples/libero/requirements.in
+requests==2.32.3
+    # via torchvision
+rich==13.9.4
+    # via tyro
+robosuite==1.4.1
+    # via -r examples/libero/requirements.in
+scipy==1.10.1
+    # via robosuite
+setuptools==75.3.0
+    # via
+    #   imageio-ffmpeg
+    #   numba
+shtab==1.7.1
+    # via tyro
+six==1.17.0
+    # via
+    #   pynput
+    #   python-dateutil
+    #   python-xlib
+termcolor==2.4.0
+    # via robosuite
+torch==1.11.0+cu113
+    # via
+    #   -r examples/libero/requirements.in
+    #   torchaudio
+    #   torchvision
+torchaudio==0.11.0+cu113
+    # via -r examples/libero/requirements.in
+torchvision==0.12.0+cu113
+    # via -r examples/libero/requirements.in
+tqdm==4.67.1
+    # via -r examples/libero/requirements.in
+typeguard==4.4.0
+    # via tyro
+typing-extensions==4.12.2
+    # via
+    #   etils
+    #   rich
+    #   torch
+    #   torchvision
+    #   typeguard
+    #   tyro
+tyro==0.9.2
+    # via -r examples/libero/requirements.in
+urllib3==2.2.3
+    # via requests
+zipp==3.20.2
+    # via
+    #   etils
+    #   importlib-metadata
+    #   importlib-resources
capvector-pi05/examples/simple_client/Dockerfile
ADDED
@@ -0,0 +1,32 @@
+# Dockerfile for the simple client.
+
+# Build the container:
+# docker build . -t simple_client -f examples/simple_client/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
+
+FROM python:3.7-slim
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
+
+CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
capvector-pi05/examples/simple_client/README.md
ADDED
@@ -0,0 +1,30 @@
+# Simple Client
+
+A minimal client that sends observations to the server and prints the inference rate.
+
+You can specify which runtime environment to use via the `--env` flag. You can see the available options by running:
+
+```bash
+uv run examples/simple_client/main.py --help
+```
+
+## With Docker
+
+```bash
+export SERVER_ARGS="--env ALOHA_SIM"
+docker compose -f examples/simple_client/compose.yml up --build
+```
+
+## Without Docker
+
+Terminal window 1:
+
+```bash
+uv run examples/simple_client/main.py --env DROID
+```
+
+Terminal window 2:
+
+```bash
+uv run scripts/serve_policy.py --env DROID
+```
capvector-pi05/examples/simple_client/compose.yml
ADDED
@@ -0,0 +1,42 @@
+# Run with:
+# docker compose -f examples/simple_client/compose.yml up --build
+services:
+  runtime:
+    image: simple_client
+    depends_on:
+      - openpi_server
+    build:
+      context: ../..
+      dockerfile: examples/simple_client/Dockerfile
+    init: true
+    tty: true
+    network_mode: host
+    volumes:
+      - $PWD:/app
+    environment:
+      - SERVER_ARGS
+
+  openpi_server:
+    image: openpi_server
+    build:
+      context: ../..
+      dockerfile: scripts/docker/serve_policy.Dockerfile
+    init: true
+    tty: true
+    network_mode: host
+    volumes:
+      - $PWD:/app
+      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+    environment:
+      - SERVER_ARGS
+      - OPENPI_DATA_HOME=/openpi_assets
+      - IS_DOCKER=true
+
+    # Comment out this block if not running on a machine with GPUs.
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
capvector-pi05/examples/simple_client/main.py
ADDED
@@ -0,0 +1,187 @@
+import dataclasses
+import enum
+import logging
+import pathlib
+import time
+
+import numpy as np
+from openpi_client import websocket_client_policy as _websocket_client_policy
+import polars as pl
+import rich.console, rich.table  # plain "import rich" does not expose rich.console/rich.table
+import tqdm
+import tyro
+
+logger = logging.getLogger(__name__)
+
+
+class EnvMode(enum.Enum):
+    """Supported environments."""
+
+    ALOHA = "aloha"
+    ALOHA_SIM = "aloha_sim"
+    DROID = "droid"
+    LIBERO = "libero"
+
+
+@dataclasses.dataclass
+class Args:
+    """Command line arguments."""
+
+    # Host and port to connect to the server.
+    host: str = "0.0.0.0"
+    # Port to connect to the server. If None, the server will use the default port.
+    port: int | None = 8000
+    # API key to use for the server.
+    api_key: str | None = None
+    # Number of steps to run the policy for.
+    num_steps: int = 20
+    # Path to save the timings to a parquet file. (e.g., timing.parquet)
+    timing_file: pathlib.Path | None = None
+    # Environment to run the policy in.
+    env: EnvMode = EnvMode.ALOHA_SIM
+
+
+class TimingRecorder:
+    """Records timing measurements for different keys."""
+
+    def __init__(self) -> None:
+        self._timings: dict[str, list[float]] = {}
+
+    def record(self, key: str, time_ms: float) -> None:
+        """Record a timing measurement for the given key."""
+        if key not in self._timings:
+            self._timings[key] = []
+        self._timings[key].append(time_ms)
+
+    def get_stats(self, key: str) -> dict[str, float]:
+        """Get statistics for the given key."""
+        times = self._timings[key]
+        return {
+            "mean": float(np.mean(times)),
+            "std": float(np.std(times)),
+            "p25": float(np.quantile(times, 0.25)),
+            "p50": float(np.quantile(times, 0.50)),
+            "p75": float(np.quantile(times, 0.75)),
+            "p90": float(np.quantile(times, 0.90)),
+            "p95": float(np.quantile(times, 0.95)),
+            "p99": float(np.quantile(times, 0.99)),
+        }
+
+    def print_all_stats(self) -> None:
+        """Print statistics for all keys in a concise format."""
+
+        table = rich.table.Table(
+            title="[bold blue]Timing Statistics[/bold blue]",
+            show_header=True,
+            header_style="bold white",
+            border_style="blue",
+            title_justify="center",
+        )
+
+        # Add metric column with custom styling
+        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
+
+        # Add statistical columns with consistent styling
+        stat_columns = [
+            ("Mean", "yellow", "mean"),
+            ("Std", "yellow", "std"),
+            ("P25", "magenta", "p25"),
+            ("P50", "magenta", "p50"),
+            ("P75", "magenta", "p75"),
+            ("P90", "magenta", "p90"),
+            ("P95", "magenta", "p95"),
+            ("P99", "magenta", "p99"),
+        ]
+
+        for name, style, _ in stat_columns:
+            table.add_column(name, justify="right", style=style, no_wrap=True)
+
+        # Add rows for each metric with formatted values
+        for key in sorted(self._timings.keys()):
+            stats = self.get_stats(key)
+            values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
+            table.add_row(key, *values)
+
+        # Print with custom console settings
+        console = rich.console.Console(width=None, highlight=True)
+        console.print(table)
+
+    def write_parquet(self, path: pathlib.Path) -> None:
+        """Save the timings to a parquet file."""
+        logger.info(f"Writing timings to {path}")
+        frame = pl.DataFrame(self._timings)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        frame.write_parquet(path)
+
+
+def main(args: Args) -> None:
+    obs_fn = {
+        EnvMode.ALOHA: _random_observation_aloha,
+        EnvMode.ALOHA_SIM: _random_observation_aloha,
+        EnvMode.DROID: _random_observation_droid,
+        EnvMode.LIBERO: _random_observation_libero,
+    }[args.env]
+
+    policy = _websocket_client_policy.WebsocketClientPolicy(
+        host=args.host,
+        port=args.port,
+        api_key=args.api_key,
+    )
+    logger.info(f"Server metadata: {policy.get_server_metadata()}")
+
+    # Send a few observations to make sure the model is loaded.
+    for _ in range(2):
+        policy.infer(obs_fn())
+
+    timing_recorder = TimingRecorder()
+
+    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
+        inference_start = time.time()
+        action = policy.infer(obs_fn())
+        timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
+        for key, value in action.get("server_timing", {}).items():
+            timing_recorder.record(f"server_{key}", value)
+        for key, value in action.get("policy_timing", {}).items():
+            timing_recorder.record(f"policy_{key}", value)
+
+    timing_recorder.print_all_stats()
+
+    if args.timing_file is not None:
+        timing_recorder.write_parquet(args.timing_file)
+
+
+def _random_observation_aloha() -> dict:
+    return {
+        "state": np.ones((14,)),
+        "images": {
+            "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+            "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+            "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+            "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+        },
+        "prompt": "do something",
+    }
+
+
+def _random_observation_droid() -> dict:
+    return {
+        "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+        "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+        "observation/joint_position": np.random.rand(7),
+        "observation/gripper_position": np.random.rand(1),
+        "prompt": "do something",
+    }
+
+
+def _random_observation_libero() -> dict:
+    return {
+        "observation/state": np.random.rand(8),
+        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+        "prompt": "do something",
+    }
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    main(tyro.cli(Args))
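When `--timing-file` is set, the recorded per-step latencies can be inspected offline with polars. A minimal sketch (assuming you passed `timing.parquet`, as in the Args comment above):

```python
import polars as pl

# One column per timing key (e.g. client_infer_ms), one row per policy step.
frame = pl.read_parquet("timing.parquet")
print(frame.describe())  # count/mean/std/min/max for each timing column
```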
capvector-pi05/examples/simple_client/requirements.in
ADDED
@@ -0,0 +1,5 @@
+numpy>=1.22.4,<2.0.0
+rich
+tqdm
+tyro
+polars
capvector-pi05/examples/simple_client/requirements.txt
ADDED
@@ -0,0 +1,30 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
+docstring-parser==0.16
+    # via tyro
+markdown-it-py==3.0.0
+    # via rich
+mdurl==0.1.2
+    # via markdown-it-py
+numpy==1.26.4
+    # via -r examples/simple_client/requirements.in
+polars==1.30.0
+    # via -r examples/simple_client/requirements.in
+pygments==2.19.1
+    # via rich
+rich==14.0.0
+    # via
+    #   -r examples/simple_client/requirements.in
+    #   tyro
+shtab==1.7.2
+    # via tyro
+tqdm==4.67.1
+    # via -r examples/simple_client/requirements.in
+typeguard==4.4.2
+    # via tyro
+typing-extensions==4.13.2
+    # via
+    #   typeguard
+    #   tyro
+tyro==0.9.22
+    # via -r examples/simple_client/requirements.in
capvector-pi05/examples/ur5/README.md
ADDED
@@ -0,0 +1,142 @@
+# UR5 Example
+
+Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
+
+First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.
+
+```python
+
+@dataclasses.dataclass(frozen=True)
+class UR5Inputs(transforms.DataTransformFn):
+
+    model_type: _model.ModelType = _model.ModelType.PI0
+
+    def __call__(self, data: dict) -> dict:
+        # First, concatenate the joints and gripper into the state vector.
+        state = np.concatenate([data["joints"], data["gripper"]])
+
+        # We may need to parse images to uint8 (H,W,C) since LeRobot automatically
+        # stores them as float32 (C,H,W); this step gets skipped for policy inference.
+        base_image = _parse_image(data["base_rgb"])
+        wrist_image = _parse_image(data["wrist_rgb"])
+
+        # Create inputs dict.
+        inputs = {
+            "state": state,
+            "image": {
+                "base_0_rgb": base_image,
+                "left_wrist_0_rgb": wrist_image,
+                # Since there is no right wrist, replace with zeros
+                "right_wrist_0_rgb": np.zeros_like(base_image),
+            },
+            "image_mask": {
+                "base_0_rgb": np.True_,
+                "left_wrist_0_rgb": np.True_,
+                # Since the "slot" for the right wrist is not used, this mask is set
+                # to False (pi0-FAST is the exception: it does not mask padding images)
+                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
+            },
+        }
+
+        if "actions" in data:
+            inputs["actions"] = data["actions"]
+
+        # Pass the prompt (aka language instruction) to the model.
+        if "prompt" in data:
+            inputs["prompt"] = data["prompt"]
+
+        return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class UR5Outputs(transforms.DataTransformFn):
+
+    def __call__(self, data: dict) -> dict:
+        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims
+        return {"actions": np.asarray(data["actions"][:, :7])}
+
+```
+
+Next, we will define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+
+```python
+
+@dataclasses.dataclass(frozen=True)
+class LeRobotUR5DataConfig(DataConfigFactory):
+
+    @override
+    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
+        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
+        repack_transform = _transforms.Group(
+            inputs=[
+                _transforms.RepackTransform(
+                    {
+                        "base_rgb": "image",
+                        "wrist_rgb": "wrist_image",
+                        "joints": "joints",
+                        "gripper": "gripper",
+                        "prompt": "prompt",
+                    }
+                )
+            ]
+        )
+
+        # These transforms are the ones we wrote earlier.
+        data_transforms = _transforms.Group(
+            inputs=[UR5Inputs(model_type=model_config.model_type)],
+            outputs=[UR5Outputs()],
+        )
+
+        # Convert absolute actions to delta actions.
+        # By convention, we do not convert the gripper action (7th dimension).
+        delta_action_mask = _transforms.make_bool_mask(6, -1)
+        data_transforms = data_transforms.push(
+            inputs=[_transforms.DeltaActions(delta_action_mask)],
+            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
+        )
+
+        # Model transforms include things like tokenizing the prompt and action targets.
+        # You do not need to change anything here for your own dataset.
+        model_transforms = ModelTransformFactory()(model_config)
+
+        # We return all data transforms for training and inference. No need to change anything here.
+        return dataclasses.replace(
+            self.create_base_config(assets_dirs),
+            repack_transforms=repack_transform,
+            data_transforms=data_transforms,
+            model_transforms=model_transforms,
+        )
+
+```
+
+Finally, we define the `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
+
+```python
+TrainConfig(
+    name="pi0_ur5",
+    model=pi0.Pi0Config(),
+    data=LeRobotUR5DataConfig(
+        repo_id="your_username/ur5_dataset",
+        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
+        # Reloading normalization stats can help transfer pre-trained models to new environments.
+        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
+        assets=AssetsConfig(
+            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
+            asset_id="ur5e",
+        ),
+        base_config=DataConfig(
+            # This flag determines whether we load the prompt (i.e. the task instruction) from the
+            # ``task`` field in the LeRobot dataset. The recommended setting is True.
+            prompt_from_task=True,
+        ),
+    ),
+    # Load the pi0 base model checkpoint.
+    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
+    num_train_steps=30_000,
+)
+```
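The `UR5Inputs` snippet above relies on a `_parse_image` helper that the README does not show. A minimal sketch of what such a helper could look like, assuming the LeRobot float32 (C,H,W) storage described in the comment (illustrative, not the repo's exact code):

```python
import numpy as np


def _parse_image(image) -> np.ndarray:
    """Hypothetical helper: normalize an image to uint8 (H, W, C)."""
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # LeRobot stores images as float32 in [0, 1]; rescale to uint8.
        image = (255 * image).astype(np.uint8)
    if image.shape[0] == 3:
        # Channel-first (C, H, W) -> channel-last (H, W, C).
        image = np.transpose(image, (1, 2, 0))
    return image
```

Once the config is registered, training is typically launched by config name through the repo's scripts, e.g. `uv run scripts/compute_norm_stats.py --config-name pi0_ur5` followed by `uv run scripts/train.py pi0_ur5 --exp-name=my_ur5_run`; treat these invocations as a sketch and check `scripts/train.py` and `scripts/compute_norm_stats.py` for the exact flags.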
capvector-pi05/packages/openpi-client/pyproject.toml
ADDED
@@ -0,0 +1,23 @@
+[project]
+name = "openpi-client"
+version = "0.1.0"
+requires-python = ">=3.7"
+dependencies = [
+    "dm-tree>=0.1.8",
+    "msgpack>=1.0.5",
+    "numpy>=1.22.4,<2.0.0",
+    "pillow>=9.0.0",
+    "tree>=0.2.4",
+    "websockets>=11.0",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.uv]
+dev-dependencies = ["pytest>=8.3.4"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py37"
capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py
ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.0"
capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py
ADDED
@@ -0,0 +1,50 @@
+from typing import Dict
+
+import numpy as np
+import tree
+from typing_extensions import override
+
+from openpi_client import base_policy as _base_policy
+
+
+class ActionChunkBroker(_base_policy.BasePolicy):
+    """Wraps a policy to return action chunks one-at-a-time.
+
+    Assumes that the first dimension of all action fields is the chunk size.
+
+    A new inference call to the inner policy is only made when the current
+    list of chunks is exhausted.
+    """
+
+    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
+        self._policy = policy
+        self._action_horizon = action_horizon
+        self._cur_step: int = 0
+
+        self._last_results: Dict[str, np.ndarray] | None = None
+
+    @override
+    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
+        if self._last_results is None:
+            self._last_results = self._policy.infer(obs)
+            self._cur_step = 0
+
+        def slicer(x):
+            if isinstance(x, np.ndarray):
+                return x[self._cur_step, ...]
+            else:
+                return x
+
+        results = tree.map_structure(slicer, self._last_results)
+        self._cur_step += 1
+
+        if self._cur_step >= self._action_horizon:
+            self._last_results = None
+
+        return results
+
+    @override
+    def reset(self) -> None:
+        self._policy.reset()
+        self._last_results = None
+        self._cur_step = 0
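To make the chunking behavior concrete, here is a minimal usage sketch. It assumes a policy server is already running on `localhost:8000` and that the server's `"actions"` output has the chunk length as its first dimension, as the docstring above states; the observation layout mirrors the LIBERO sample in `examples/simple_client/main.py`:

```python
import numpy as np

from openpi_client import action_chunk_broker, websocket_client_policy


def random_obs() -> dict:
    # Placeholder LIBERO-style observation (copied from examples/simple_client/main.py).
    return {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "prompt": "do something",
    }


# Wrap the remote policy so each infer() call pops one timestep from a cached chunk;
# a new server round-trip happens only every `action_horizon` steps.
inner = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
policy = action_chunk_broker.ActionChunkBroker(inner, action_horizon=10)

for _ in range(20):
    result = policy.infer(random_obs())  # network call on steps 0 and 10 only
    action = result["actions"]           # one timestep sliced from the cached chunk
```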
capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py
ADDED
@@ -0,0 +1,12 @@
+import abc
+from typing import Dict
+
+
+class BasePolicy(abc.ABC):
+    @abc.abstractmethod
+    def infer(self, obs: Dict) -> Dict:
+        """Infer actions from observations."""
+
+    def reset(self) -> None:
+        """Reset the policy to its initial state."""
+        pass
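This abstract interface is all a policy needs to plug into the runtime or the `ActionChunkBroker` above. A minimal sketch of a custom implementation (illustrative; the constant chunk shape is made up):

```python
from typing import Dict

import numpy as np

from openpi_client import base_policy as _base_policy


class ConstantPolicy(_base_policy.BasePolicy):
    """Returns the same 10-step chunk of zero actions (7 dims, e.g. 6 DoF + gripper)."""

    def infer(self, obs: Dict) -> Dict:
        return {"actions": np.zeros((10, 7))}
```

`reset()` is optional: the base class provides a no-op default.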
capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
|
| 6 |
+
"""Converts an image to uint8 if it is a float image.
|
| 7 |
+
|
| 8 |
+
This is important for reducing the size of the image when sending it over the network.
|
| 9 |
+
"""
|
| 10 |
+
if np.issubdtype(img.dtype, np.floating):
|
| 11 |
+
img = (255 * img).astype(np.uint8)
|
| 12 |
+
return img
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR, return_mask=False) -> np.ndarray:
|
| 16 |
+
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
images: A batch of images in [..., height, width, channel] format.
|
| 20 |
+
height: The target height of the image.
|
| 21 |
+
width: The target width of the image.
|
| 22 |
+
method: The interpolation method to use. Default is bilinear.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
The resized images in [..., height, width, channel].
|
| 26 |
+
"""
|
| 27 |
+
# If the images are already the correct size, return them as is.
|
| 28 |
+
if images.shape[-3:-1] == (height, width):
|
| 29 |
+
if return_mask:
|
| 30 |
+
img_padding_mask = np.ones((*images.shape[:-3], height, width), dtype=bool)
|
| 31 |
+
return images, img_padding_mask
|
| 32 |
+
return images
|
| 33 |
+
|
| 34 |
+
original_shape = images.shape
|
| 35 |
+
|
| 36 |
+
images = images.reshape(-1, *original_shape[-3:])
|
| 37 |
+
|
| 38 |
+
resized_results = [
|
| 39 |
+
_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images
|
| 40 |
+
]
|
| 41 |
+
resized_images, img_padding_mask = zip(*resized_results)
|
| 42 |
+
resized_images = np.stack(resized_images)
|
| 43 |
+
img_padding_mask = np.stack(img_padding_mask)
|
| 44 |
+
|
| 45 |
+
if return_mask:
|
| 46 |
+
return (
|
| 47 |
+
resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:]),
|
| 48 |
+
img_padding_mask.reshape(*original_shape[:-3], *img_padding_mask.shape[-2:]),
|
| 49 |
+
)
|
| 50 |
+
else:
|
| 51 |
+
return resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
| 55 |
+
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
| 56 |
+
width without distortion by padding with zeros.
|
| 57 |
+
|
| 58 |
+
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
| 59 |
+
"""
|
| 60 |
+
cur_width, cur_height = image.size
|
| 61 |
+
if cur_width == width and cur_height == height:
|
| 62 |
+
return image # No need to resize if the image is already the correct size.
|
| 63 |
+
|
| 64 |
+
ratio = max(cur_width / width, cur_height / height)
|
| 65 |
+
resized_height = int(cur_height / ratio)
|
| 66 |
+
resized_width = int(cur_width / ratio)
|
| 67 |
+
resized_image = image.resize((resized_width, resized_height), resample=method)
|
| 68 |
+
|
| 69 |
+
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
| 70 |
+
pad_height = max(0, int((height - resized_height) / 2))
|
| 71 |
+
pad_width = max(0, int((width - resized_width) / 2))
|
| 72 |
+
zero_image.paste(resized_image, (pad_width, pad_height))
|
| 73 |
+
assert zero_image.size == (width, height)
|
| 74 |
+
|
| 75 |
+
img_padding_mask = np.zeros((height, width), dtype=bool)
|
| 76 |
+
img_padding_mask[pad_height:pad_height+resized_height, pad_width:pad_width+resized_width] = True
|
| 77 |
+
|
| 78 |
+
return zero_image, img_padding_mask
|
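A minimal usage sketch for the helpers above (the batch shape and 224x224 target size are illustrative, not anything the API requires):

# Sketch: resize a batch of camera frames and track which output pixels are padding.
import numpy as np

frames = np.random.randint(0, 256, size=(4, 480, 640, 3), dtype=np.uint8)  # hypothetical camera batch
resized, mask = resize_with_pad(frames, 224, 224, return_mask=True)
assert resized.shape == (4, 224, 224, 3)
assert mask.shape == (4, 224, 224)
# A 640x480 frame scaled into 224x224 keeps its 4:3 aspect ratio (224x168 content),
# so rows at the top and bottom are zero padding and the mask is False there.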
capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py
ADDED
@@ -0,0 +1,37 @@
import numpy as np

import openpi_client.image_tools as image_tools


def test_resize_with_pad_shapes():
    # Test case 1: Upscale images to larger target dimensions.
    images = np.zeros((2, 10, 10, 3), dtype=np.uint8)  # Input images of shape (batch_size, height, width, channels)
    height = 20
    width = 20
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (2, height, width, 3)
    assert np.all(resized_images == 0)

    # Test case 2: Downscale images to smaller target dimensions.
    images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
    height = 15
    width = 15
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (3, height, width, 3)
    assert np.all(resized_images == 0)

    # Test case 3: Images that already have the target dimensions are returned unchanged.
    images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
    height = 50
    width = 50
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert np.all(resized_images == 0)

    # Test case 4: Resize with odd-numbered padding.
    images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
    height = 60
    width = 80
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert np.all(resized_images == 0)
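The tests above use all-zero inputs, so they verify shapes but cannot tell image content apart from zero padding. A complementary check one could add (not in the original file) uses a constant non-zero image to confirm where the padding actually lands:

# Sketch of an additional test: 320x256 resized into 80x60 preserves aspect ratio as 75x60,
# leaving ~2 columns of zero padding on each side.
def test_resize_with_pad_padding_location():
    images = np.full((1, 256, 320, 3), 255, dtype=np.uint8)  # all-white input
    resized = image_tools.resize_with_pad(images, 60, 80)
    assert np.all(resized[0, :, 0] == 0)      # leftmost column is padding
    assert np.all(resized[0, 30, 40] == 255)  # center pixel is image content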
capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py
ADDED
@@ -0,0 +1,57 @@
"""Adds NumPy array support to msgpack.

msgpack is good for (de)serializing data over a network for multiple reasons:
- msgpack is secure (as opposed to pickle/dill/etc, which allow for arbitrary code execution)
- msgpack is widely used and has good cross-language support
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc), which is convenient in
  dynamically typed languages like Python and JavaScript
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack
  was ~4x faster than pickle for serializing large arrays using the below strategy

The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library
directly is that it falls back to pickle for object arrays.
"""

import functools

import msgpack
import numpy as np


def pack_array(obj):
    if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    if isinstance(obj, np.ndarray):
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": obj.dtype.str,
            b"shape": obj.shape,
        }

    if isinstance(obj, np.generic):
        return {
            b"__npgeneric__": True,
            b"data": obj.item(),
            b"dtype": obj.dtype.str,
        }

    return obj


def unpack_array(obj):
    if b"__ndarray__" in obj:
        return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])

    if b"__npgeneric__" in obj:
        return np.dtype(obj[b"dtype"]).type(obj[b"data"])

    return obj


Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)

Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
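A round-trip usage sketch for the helpers above (the observation keys are made up for illustration):

# Sketch: serialize a nested observation dict with arrays, send it as bytes, and restore it.
import numpy as np

obs = {"state": np.zeros(7, dtype=np.float32), "image": np.zeros((224, 224, 3), dtype=np.uint8)}
payload = packb(obs)          # bytes, safe to send over a socket
restored = unpackb(payload)   # arrays come back with their original dtype and shape
assert restored["state"].dtype == np.float32
assert restored["image"].shape == (224, 224, 3)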
capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py
ADDED
@@ -0,0 +1,45 @@
import numpy as np
import pytest
import tree

from openpi_client import msgpack_numpy


def _check(expected, actual):
    if isinstance(expected, np.ndarray):
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
        assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
    else:
        assert expected == actual


@pytest.mark.parametrize(
    "data",
    [
        1,  # int
        1.0,  # float
        "hello",  # string
        np.bool_(True),  # boolean scalar
        np.array([1, 2, 3])[0],  # int scalar
        np.str_("asdf"),  # string scalar
        [1, 2, 3],  # list
        {"key": "value"},  # dict
        {"key": [1, 2, 3]},  # nested dict
        np.array(1.0),  # 0D array
        np.array([1, 2, 3], dtype=np.int32),  # 1D integer array
        np.array(["asdf", "qwer"]),  # string array
        np.array([True, False]),  # boolean array
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),  # 2D float array
        np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16),  # 3D integer array
        np.array([np.nan, np.inf, -np.inf]),  # special float values
        {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}},  # nested dict with arrays
        [np.array([1, 2]), np.array([3, 4])],  # list of arrays
        np.zeros((3, 4, 5), dtype=np.float32),  # 3D zeros
        np.ones((2, 3), dtype=np.float64),  # 2D ones with double precision
    ],
)
def test_pack_unpack(data):
    packed = msgpack_numpy.packb(data)
    unpacked = msgpack_numpy.unpackb(packed)
    tree.map_structure(_check, data, unpacked)
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py
ADDED
@@ -0,0 +1,17 @@
import abc


class Agent(abc.ABC):
    """An Agent is the thing with agency, i.e. the entity that makes decisions.

    Agents receive observations about the state of the world, and return actions
    to take in response.
    """

    @abc.abstractmethod
    def get_action(self, observation: dict) -> dict:
        """Query the agent for the next action."""

    @abc.abstractmethod
    def reset(self) -> None:
        """Reset the agent to its initial state."""
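A minimal concrete Agent, for illustration only (not part of the package; the 7-DoF action key is a made-up example):

import numpy as np

class RandomAgent(Agent):
    """Illustrative sketch: emits a random 7-DoF action and keeps no internal state."""

    def get_action(self, observation: dict) -> dict:
        return {"actions": np.random.uniform(-1.0, 1.0, size=7)}

    def reset(self) -> None:
        pass  # stateless, nothing to reset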
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
ADDED
@@ -0,0 +1,18 @@
from typing_extensions import override

from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent


class PolicyAgent(_agent.Agent):
    """An agent that uses a policy to determine actions."""

    def __init__(self, policy: _base_policy.BasePolicy) -> None:
        self._policy = policy

    @override
    def get_action(self, observation: dict) -> dict:
        return self._policy.infer(observation)

    def reset(self) -> None:
        self._policy.reset()
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py
ADDED
@@ -0,0 +1,32 @@
import abc


class Environment(abc.ABC):
    """An Environment represents the robot and the environment it inhabits.

    The primary contract of environments is that they can be queried for observations
    about their state, and have actions applied to them to change that state.
    """

    @abc.abstractmethod
    def reset(self) -> None:
        """Reset the environment to its initial state.

        This will be called once before starting each episode.
        """

    @abc.abstractmethod
    def is_episode_complete(self) -> bool:
        """Allow the environment to signal that the episode is complete.

        This will be called after each step. It should return `True` if the episode is
        complete (either successfully or unsuccessfully), and `False` otherwise.
        """

    @abc.abstractmethod
    def get_observation(self) -> dict:
        """Query the environment for the current state."""

    @abc.abstractmethod
    def apply_action(self, action: dict) -> None:
        """Take an action in the environment."""
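A minimal concrete Environment, again purely illustrative (a real implementation would talk to a robot or simulator):

class CountdownEnvironment(Environment):
    """Illustrative sketch: terminates after a fixed number of applied actions."""

    def __init__(self, num_steps: int = 10) -> None:
        self._num_steps = num_steps
        self._remaining = num_steps

    def reset(self) -> None:
        self._remaining = self._num_steps

    def is_episode_complete(self) -> bool:
        return self._remaining <= 0

    def get_observation(self) -> dict:
        return {"remaining": self._remaining}

    def apply_action(self, action: dict) -> None:
        self._remaining -= 1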
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py
ADDED
@@ -0,0 +1,92 @@
import logging
import threading
import time

from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import subscriber as _subscriber


class Runtime:
    """The core module orchestrating interactions between key components of the system."""

    def __init__(
        self,
        environment: _environment.Environment,
        agent: _agent.Agent,
        subscribers: list[_subscriber.Subscriber],
        max_hz: float = 0,
        num_episodes: int = 1,
        max_episode_steps: int = 0,
    ) -> None:
        self._environment = environment
        self._agent = agent
        self._subscribers = subscribers
        self._max_hz = max_hz
        self._num_episodes = num_episodes
        self._max_episode_steps = max_episode_steps

        self._in_episode = False
        self._episode_steps = 0

    def run(self) -> None:
        """Runs episodes back to back until `num_episodes` is reached. Each episode ends when the
        environment reports completion, `max_episode_steps` is hit, or `mark_episode_complete()` is called."""
        for _ in range(self._num_episodes):
            self._run_episode()

        # Final reset, this is important for real environments to move the robot to its home position.
        self._environment.reset()

    def run_in_new_thread(self) -> threading.Thread:
        """Runs the runtime loop in a new thread."""
        thread = threading.Thread(target=self.run)
        thread.start()
        return thread

    def mark_episode_complete(self) -> None:
        """Marks the end of an episode."""
        self._in_episode = False

    def _run_episode(self) -> None:
        """Runs a single episode."""
        logging.info("Starting episode...")
        self._environment.reset()
        self._agent.reset()
        for subscriber in self._subscribers:
            subscriber.on_episode_start()

        self._in_episode = True
        self._episode_steps = 0
        step_time = 1 / self._max_hz if self._max_hz > 0 else 0
        last_step_time = time.time()

        while self._in_episode:
            self._step()
            self._episode_steps += 1

            # Sleep to maintain the desired control rate.
            now = time.time()
            dt = now - last_step_time
            if dt < step_time:
                time.sleep(step_time - dt)
                last_step_time = time.time()
            else:
                last_step_time = now

        logging.info("Episode completed.")
        for subscriber in self._subscribers:
            subscriber.on_episode_end()

    def _step(self) -> None:
        """A single step of the runtime loop."""
        observation = self._environment.get_observation()
        action = self._agent.get_action(observation)
        self._environment.apply_action(action)

        for subscriber in self._subscribers:
            subscriber.on_step(observation, action)

        if self._environment.is_episode_complete() or (
            self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
        ):
            self.mark_episode_complete()
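A wiring sketch showing how the pieces fit together, using the illustrative classes from the sketches above:

# Sketch: run two rate-limited episodes with a random agent in a toy environment.
runtime = Runtime(
    environment=CountdownEnvironment(num_steps=10),
    agent=RandomAgent(),
    subscribers=[],
    max_hz=50,            # cap the control loop at 50 Hz
    num_episodes=2,
    max_episode_steps=100,
)
runtime.run()  # blocks; use runtime.run_in_new_thread() to run alongside other work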
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py
ADDED
@@ -0,0 +1,20 @@
import abc


class Subscriber(abc.ABC):
    """Subscribes to events in the runtime.

    Subscribers can be used to save data, visualize, etc.
    """

    @abc.abstractmethod
    def on_episode_start(self) -> None:
        """Called when an episode starts."""

    @abc.abstractmethod
    def on_step(self, observation: dict, action: dict) -> None:
        """Called after each step with the observation and the action that was taken."""

    @abc.abstractmethod
    def on_episode_end(self) -> None:
        """Called when an episode ends."""
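A minimal illustrative subscriber (not part of the package); the runtime calls on_episode_start before any on_step:

class StepCounter(Subscriber):
    """Illustrative sketch: counts steps and reports the episode length."""

    def on_episode_start(self) -> None:
        self._num_steps = 0

    def on_step(self, observation: dict, action: dict) -> None:
        self._num_steps += 1

    def on_episode_end(self) -> None:
        print(f"episode finished after {self._num_steps} steps")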
capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py
ADDED
@@ -0,0 +1,55 @@
import logging
import time
from typing import Dict, Optional, Tuple

from typing_extensions import override
import websockets.sync.client

from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy


class WebsocketClientPolicy(_base_policy.BasePolicy):
    """Implements the Policy interface by communicating with a server over websocket.

    See WebsocketPolicyServer for a corresponding server implementation.
    """

    def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
        self._uri = f"ws://{host}"
        if port is not None:
            self._uri += f":{port}"
        self._packer = msgpack_numpy.Packer()
        self._api_key = api_key
        self._ws, self._server_metadata = self._wait_for_server()

    def get_server_metadata(self) -> Dict:
        return self._server_metadata

    def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
        logging.info(f"Waiting for server at {self._uri}...")
        while True:
            try:
                headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
                conn = websockets.sync.client.connect(
                    self._uri, compression=None, max_size=None, additional_headers=headers
                )
                metadata = msgpack_numpy.unpackb(conn.recv())
                return conn, metadata
            except ConnectionRefusedError:
                logging.info("Still waiting for server...")
                time.sleep(5)

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        data = self._packer.pack(obs)
        self._ws.send(data)
        response = self._ws.recv()
        if isinstance(response, str):
            # We expect bytes; if the server sends a string, it is an error message.
            raise RuntimeError(f"Error in inference server:\n{response}")
        return msgpack_numpy.unpackb(response)

    @override
    def reset(self) -> None:
        pass
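A client-side usage sketch (host, port, and observation keys are illustrative; the server must speak the msgpack_numpy protocol above):

import numpy as np

# The constructor blocks until a server is reachable at the given address.
client = WebsocketClientPolicy(host="localhost", port=8000)
print(client.get_server_metadata())
result = client.infer({"state": np.zeros(7, dtype=np.float32)})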
capvector-pi05/scripts/__init__.py
ADDED
File without changes
capvector-pi05/scripts/compute_norm_stats.py
ADDED
@@ -0,0 +1,117 @@
"""Compute normalization statistics for a config.

This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save them
to the config assets directory.
"""

import numpy as np
import tqdm
import tyro

import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms


class RemoveStrings(transforms.DataTransformFn):
    def __call__(self, x: dict) -> dict:
        return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}


def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")
    dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
        shuffle = True
    else:
        num_batches = len(dataset) // batch_size
        shuffle = False
    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        num_batches=num_batches,
    )
    return data_loader, num_batches


def create_rlds_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)
    dataset = _data_loader.IterableTransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
        is_batched=True,
    )
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
    else:
        # NOTE: this length is currently hard-coded for DROID.
        num_batches = len(dataset) // batch_size
    data_loader = _data_loader.RLDSDataLoader(
        dataset,
        num_batches=num_batches,
    )
    return data_loader, num_batches


def main(config_name: str, max_frames: int | None = None):
    config = _config.get_config(config_name)
    data_config = config.data.create(config.assets_dirs, config.model)

    if data_config.rlds_data_dir is not None:
        data_loader, num_batches = create_rlds_dataloader(
            data_config, config.model.action_horizon, config.batch_size, max_frames
        )
    else:
        data_loader, num_batches = create_torch_dataloader(
            data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames
        )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        for key in keys:
            stats[key].update(np.asarray(batch[key]))

    norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}

    output_path = config.assets_dirs / data_config.repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)


if __name__ == "__main__":
    tyro.cli(main)
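The script streams batches through `normalize.RunningStats` rather than holding the dataset in memory. The exact interface of that class is not shown here, but the general pattern it implements looks like this standalone sketch (names and the sum-of-squares formulation are illustrative; real implementations often prefer Welford's algorithm for numerical stability):

import numpy as np

class StreamingStats:
    """Sketch: accumulate per-dimension mean/std over batches without storing the data."""

    def __init__(self):
        self.n = 0
        self.s1 = 0.0  # running sum of x
        self.s2 = 0.0  # running sum of x**2

    def update(self, batch: np.ndarray) -> None:
        flat = batch.reshape(-1, batch.shape[-1])  # treat all leading dims as samples
        self.n += flat.shape[0]
        self.s1 += flat.sum(axis=0)
        self.s2 += (flat**2).sum(axis=0)

    def get_statistics(self):
        mean = self.s1 / self.n
        std = np.sqrt(self.s2 / self.n - mean**2)
        return mean, std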
capvector-pi05/scripts/docker/compose.yml
ADDED
@@ -0,0 +1,29 @@
# Run with:
# docker compose -f scripts/docker/compose.yml up --build
services:
  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    # Mount the repo at /app and the configured openpi data home at /openpi_assets inside the container.
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
capvector-pi05/scripts/docker/install_docker_ubuntu22.sh
ADDED
@@ -0,0 +1,37 @@
#!/bin/bash

# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
  "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
  sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update

sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker "$username"

# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service

# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
# Note: use -f to test for the file; a bare [ string ] is always true.
if [ -f ~/.docker/config.json ]; then
  sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi

echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""
capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh
ADDED
@@ -0,0 +1,17 @@
#!/bin/bash

# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html

curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
  curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
  sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
  sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
capvector-pi05/scripts/docker/serve_policy.Dockerfile
ADDED
@@ -0,0 +1,38 @@
# Dockerfile for serving a PI policy.
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container

# Build the container:
# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang

# Copy from the cache instead of linking since it's a mounted volume.
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Install the project's dependencies using the lockfile and settings.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
    GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev

# Copy transformers_replace files while preserving directory structure.
COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace

CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
capvector-pi05/scripts/serve_policy.py
ADDED
@@ -0,0 +1,122 @@
import dataclasses
import enum
import logging
import socket

import tyro

from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config


class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"


@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Training config name (e.g., "pi0_aloha_sim").
    config: str
    # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    dir: str


@dataclasses.dataclass
class Default:
    """Use the default policy for the given environment."""


@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # Environment to serve the policy for. This is only used when serving default policies.
    env: EnvMode = EnvMode.ALOHA_SIM

    # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't
    # have a default prompt.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000
    # Record the policy's behavior for debugging.
    record: bool = False

    # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)


# Default checkpoints that should be used for each environment.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    EnvMode.ALOHA: Checkpoint(
        config="pi05_aloha",
        dir="gs://openpi-assets/checkpoints/pi05_base",
    ),
    EnvMode.ALOHA_SIM: Checkpoint(
        config="pi0_aloha_sim",
        dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
    ),
    EnvMode.DROID: Checkpoint(
        config="pi05_droid",
        dir="gs://openpi-assets/checkpoints/pi05_droid",
    ),
    EnvMode.LIBERO: Checkpoint(
        config="pi05_libero",
        dir="gs://openpi-assets/checkpoints/pi05_libero",
    ),
}


def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Create a default policy for the given environment."""
    if checkpoint := DEFAULT_CHECKPOINT.get(env):
        return _policy_config.create_trained_policy(
            _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
        )
    raise ValueError(f"Unsupported environment mode: {env}")


def create_policy(args: Args) -> _policy.Policy:
    """Create a policy from the given arguments."""
    match args.policy:
        case Checkpoint():
            return _policy_config.create_trained_policy(
                _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
            )
        case Default():
            return create_default_policy(args.env, default_prompt=args.default_prompt)


def main(args: Args) -> None:
    policy = create_policy(args)
    policy_metadata = policy.metadata

    # Record the policy's behavior for debugging.
    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    local_ip = socket.gethostbyname(hostname)
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata=policy_metadata,
    )
    server.serve_forever()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))
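Typical invocations, inferred from the Args dataclass above (the subcommand syntax follows tyro's convention for union-typed fields and may vary with tyro version; the checkpoint path is illustrative):

# Serve the default policy for an environment:
#   uv run scripts/serve_policy.py --env LIBERO
# Serve a specific trained checkpoint:
#   uv run scripts/serve_policy.py policy:checkpoint \
#       --policy.config pi05_libero --policy.dir checkpoints/pi05_libero/exp/10000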
capvector-pi05/scripts/train.py
ADDED
@@ -0,0 +1,280 @@
import dataclasses
import functools
import logging
import platform
from typing import Any

import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders


def init_logging():
    """Custom logging format for better readability."""
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)


def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Loads and validates the weights. Returns a loaded subset of the weights."""
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
    return traverse_util.unflatten_dict(
        {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
    )


@at.typecheck
def init_train_state(
    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        # Initialize the model (and its parameters).
        model = config.model.create(model_rng)

        # Merge the partial params into the model.
        if partial_params is not None:
            graphdef, state = nnx.split(model)
            # This will produce an error if the partial params are not a subset of the state.
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params to bfloat16.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            opt_state=tx.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize the train state and mix in the partial params.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),  # donate the partial params buffer.
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding


@at.typecheck
def train_step(
    config: _config.TrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
    ):
        chunked_loss = model.compute_loss(rng, observation, actions, train=True)
        return jnp.mean(chunked_loss)

    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Filter out frozen params.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info


def main(config: _config.TrainConfig):
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    data_loader = _data_loader.create_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    # Log images from the first batch as a sanity check.
    images_to_log = [
        wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
        for i in range(min(5, len(next(iter(batch[0].images.values())))))
    ]
    wandb.log({"camera_views": images_to_log}, step=0)

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    main(_config.cli())
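The EMA update inside train_step is worth seeing in isolation. A standalone sketch of the same jax.tree.map expression (the decay value and parameter tree are illustrative):

# Sketch: with ema_decay = 0.99, each step keeps 99% of the running average
# and mixes in 1% of the fresh parameters.
import jax
import jax.numpy as jnp

ema_decay = 0.99
ema_params = {"w": jnp.ones(3)}
new_params = {"w": jnp.zeros(3)}
ema_params = jax.tree.map(lambda old, new: ema_decay * old + (1 - ema_decay) * new, ema_params, new_params)
# ema_params["w"] is now [0.99, 0.99, 0.99]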
capvector-pi05/scripts/train_align_pytorch.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.

Usage
  Single GPU:
    python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
  Example:
    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
  Multi-GPU (single node):
    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
  Example:
    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
  Multi-Node Training:
    torchrun \
        --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
        --master_addr=<master_ip> --master_port=<port> \
        scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
"""

import dataclasses
import gc
import logging
import os
import platform
import shutil
import time

import jax
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
import torch.nn.parallel
import tqdm
import wandb

import openpi.models.pi0_config
from openpi.models_pytorch import pi0_pytorch, pi0_align_pytorch, projectors
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data

from vggt.models.vggt import VGGT


def init_logging():
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    else:
        logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging."""
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)


def setup_ddp():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1
    if use_ddp and not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")

    # Set up debugging environment variables for DDP issues
    if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    return use_ddp, local_rank, device
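
A minimal sketch of the environment contract that `setup_ddp` above relies on: torchrun exports WORLD_SIZE, RANK, and LOCAL_RANK into every process it launches, and the fallback values below only matter for plain single-process runs:

import os

world_size = int(os.environ.get("WORLD_SIZE", "1"))  # total processes across all nodes
rank = int(os.environ.get("RANK", "0"))              # global rank of this process
local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # GPU index on this node
print(f"process {rank}/{world_size} pinned to cuda:{local_rank}")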


def cleanup_ddp():
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


def set_seed(seed: int, local_rank: int):
    torch.manual_seed(seed + local_rank)
    np.random.seed(seed + local_rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed + local_rank)


def build_datasets(config: _config.TrainConfig):
    # Use the unified data loader with PyTorch framework
    data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return data_loader, data_loader.data_config()
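
A minimal consumption sketch for the loader built above, mirroring how the training loop below unpacks it (the shape in the comment is an assumption based on the model config, not guaranteed by the API):

loader, data_config = build_datasets(config)   # config: any _config.TrainConfig
observation, actions = next(iter(loader))      # each batch is an (Observation, actions) pair
print(actions.shape)                           # presumably [batch, action_horizon, action_dim]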


def get_model_state_dict(model):
    """Get state dict from model, handling DDP wrapper."""
    return (
        model.module.state_dict()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.state_dict()
    )


def get_model_parameters(model):
    """Get parameters from model, handling DDP wrapper."""
    return (
        model.module.parameters()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.parameters()
    )


def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata."""
    if not is_main:
        return

    # Only save if it's time to save or if it's the final step
    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
        # Create temporary directory for atomic checkpoint saving
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Remove any existing temp directory and create a new one
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save model state using safetensors (handles shared tensors)
        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

        # Save optimizer state using PyTorch format
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Save training metadata; the config is stored as a plain dict
        # (dataclasses.asdict) to avoid pickling JAX/Flax objects
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # Save norm stats
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

        # Atomically move the temp directory to the final location
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        # Log checkpoint to wandb
        if config.wandb_enabled:
            wandb.log({"checkpoint_step": global_step}, step=global_step)
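
The write-to-temp-then-rename dance above is the standard trick for crash-safe checkpoints: an interrupted save leaves only a tmp_* directory, which the loaders below deliberately skip. The pattern in isolation (names and paths are illustrative, and the rename is only atomic when both directories live on the same filesystem):

from pathlib import Path
import shutil

def atomic_write_dir(final_dir: Path, write_fn):
    """Write into a sibling tmp dir, then rename so final_dir appears all-or-nothing."""
    tmp_dir = final_dir.parent / f"tmp_{final_dir.name}"
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(parents=True)
    write_fn(tmp_dir)                 # any number of file writes may land here
    if final_dir.exists():
        shutil.rmtree(final_dir)
    tmp_dir.rename(final_dir)         # single rename: readers never see a half-written dir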


def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    latest_step = max(checkpoint_steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Clear memory before loading checkpoints
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # Load model state with error handling
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"

        if safetensors_path.exists():
            model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
            logging.info("Loaded model state from safetensors format")
        else:
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # Load optimizer state with error handling
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"

        if optimizer_path.exists():
            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
            logging.info("Loaded optimizer state from pt format")
        else:
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")

        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # Load metadata
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Clear memory and provide a detailed error message
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise


def get_latest_checkpoint_step(checkpoint_dir):
    """Get the latest checkpoint step number from a checkpoint directory."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]
    return max(checkpoint_steps) if checkpoint_steps else None


def log_memory_usage(device, step, phase="unknown"):
    """Log detailed memory usage information."""
    if not torch.cuda.is_available():
        return

    memory_allocated = torch.cuda.memory_allocated(device) / 1e9
    memory_reserved = torch.cuda.memory_reserved(device) / 1e9
    # "free" here means reserved-but-unallocated within the caching allocator
    memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9

    # Get more detailed memory info
    memory_stats = torch.cuda.memory_stats(device)
    max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
    max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9

    # Get DDP info if available
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )
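
Both the saver and loader above go through safetensors, which stores raw tensors only (no pickled code). A minimal round trip with the same two calls the script uses; the model and path here are placeholders:

import safetensors.torch
import torch

model = torch.nn.Linear(4, 2)
safetensors.torch.save_model(model, "model.safetensors")   # handles shared/tied tensors
safetensors.torch.load_model(model, "model.safetensors")   # loads the weights back in place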


def train_loop(config: _config.TrainConfig):
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # Initialize checkpoint directory and wandb
    resuming = False
    if config.resume:
        # Find checkpoint directory based on experiment name
        exp_checkpoint_dir = config.checkpoint_dir
        if exp_checkpoint_dir.exists():
            # Use validation to find the latest working checkpoint
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is not None:
                resuming = True
                logging.info(
                    f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
                )
            else:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        else:
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    # Create checkpoint directory with experiment name
    if not resuming:
        # For new runs, create an experiment-specific checkpoint directory
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        # For resume, checkpoint_dir is already set to the experiment directory
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Build data loader using the unified data loader.
    # Calculate the effective batch size per GPU for DDP:
    # for N GPUs, each GPU should get batch_size/N samples, so the total across all GPUs is batch_size.
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size  # e.g. batch_size=32 on 4 GPUs -> 8 per rank
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    # Pass the original batch size to the data loader - it handles DDP splitting internally
    loader, data_config = build_datasets(config)

    # Log sample images to wandb on the first batch
    if is_main and config.wandb_enabled and not resuming:
        # Create a separate data loader for the sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))
        # Convert observation and actions to torch tensors
        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        # Create sample images for wandb
        images_to_log = []
        # Get batch size from the first image tensor
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # Concatenate all camera views horizontally for this batch item,
            # converting from NCHW to NHWC format for wandb
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        # Clear sample batch from memory aggressively
        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader  # Also delete the sample data loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = pi0_align_pytorch.PI0Pytorch(model_cfg, config).to(device)
    vggt_model = VGGT(
        enable_camera=False,
        enable_point=False,
        enable_depth=False,
        enable_track=False,
        feature_only=True,
    ).to(device)
    align_projector = projectors.AlignProjector(model.LLM_width, config.vggt_dim, config.use_vlm_norm).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=True,  # Some parameters may not receive gradients every step
            gradient_as_bucket_view=True,  # Alias gradients to bucket views to save memory
            static_graph=world_size >= 8,  # Assume a fixed graph for 8+ GPUs
        )
        align_projector = torch.nn.parallel.DistributedDataParallel(
            align_projector,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=True,  # Some parameters may not receive gradients every step
            gradient_as_bucket_view=True,  # Alias gradients to bucket views to save memory
            static_graph=world_size >= 8,  # Assume a fixed graph for 8+ GPUs
        )

    # Load weights from weight_loader if specified (for fine-tuning)
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")
        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        safetensors.torch.load_model(
            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model),
            model_path,
            strict=False,
        )
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
    if config.vggt_weight_path is not None:
        vggt_path = os.path.join(config.vggt_weight_path, "model.pt")
        if not os.path.exists(vggt_path):
            raise FileNotFoundError(f"VGGT weight file not found at {vggt_path}")
        # map_location keeps the load off the checkpoint's original device
        vggt_model.load_state_dict(torch.load(vggt_path, map_location=device), strict=False)
        logging.info(f"Loaded VGGT weights from {config.vggt_weight_path}")

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters; a single param group covers both trainable modules
    optim = torch.optim.AdamW(
        list(model.parameters()) + list(align_projector.parameters()),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos
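
    # Worked example of the schedule above (illustrative values, not from any config):
    # with warmup_steps=1000, peak_lr=2.5e-5, decay_steps=30000, end_lr=2.5e-6:
    #   step 0     -> 2.5e-5 / 1001 ≈ 2.5e-8  (the JAX-matching warmup start)
    #   step 1000  -> 2.5e-5                  (cosine branch begins at the peak)
    #   step 30000 -> 2.5e-6                  (cos term reaches 0, lr stays at end_lr)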

    model.train()
    align_projector.train()
    vggt_model.eval()
    start_time = time.time()
    infos = []  # Collect stats over the log interval
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            # The unified data loader returns an (observation, actions) tuple;
            # jax.tree.map is used purely as a pytree mapper to move every torch tensor to the device
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            action_losses, align_loss = model(observation, actions, vggt=vggt_model, align_proj=align_projector)
            loss = action_losses + config.align_loss_coeff * align_loss

            # Backward pass
            loss.backward()

            # Log memory usage after the backward pass
            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            # Gradient clipping over both modules the optimizer updates
            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(model.parameters()) + list(align_projector.parameters()),
                max_norm=config.optimizer.clip_gradient_norm,
            )

            # Optimizer step
            optim.step()
            optim.zero_grad(set_to_none=True)

            # Clear gradients more aggressively
            # (already freed by zero_grad(set_to_none=True); kept as belt and braces)
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad = None

            # Collect stats
            if is_main:
                infos.append(
                    {
                        "action_loss": action_losses.item(),
                        "align_loss": align_loss.item(),
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time

                # Average stats over the log interval
                avg_loss = sum(info["action_loss"] for info in infos) / len(infos)
                avg_align_loss = sum(info["align_loss"] for info in infos) / len(infos)
                avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)

                avg_grad_norm = None
                if any("grad_norm" in info for info in infos):
                    vals = [
                        info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
                    ]
                    if len(vals) > 0:
                        avg_grad_norm = sum(vals) / len(vals)
                logging.info(
                    f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
                    if avg_grad_norm is not None
                    else f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
                )

                # Log to wandb
                if config.wandb_enabled and len(infos) > 0:
                    log_payload = {
                        "action_loss": avg_loss,
                        "align_loss": avg_align_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        "time_per_step": elapsed / config.log_interval,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []  # Reset stats collection

            global_step += 1
            # Save checkpoint using the atomic mechanism above
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            # Update progress bar
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    # Close progress bar
    if pbar is not None:
        pbar.close()

    # Finish wandb run
    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()


def main():
    init_logging()
    config = _config.cli()
    train_loop(config)


if __name__ == "__main__":
    main()
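
A quick arithmetic check of the batch-size split that both trainers log (values illustrative, not from any config):

world_size = 4                      # torch.distributed.get_world_size()
batch_size = 32                     # config.batch_size is the global batch
per_gpu = batch_size // world_size  # the slice each rank actually loads
assert per_gpu == 8                 # DDP then averages the four per-rank gradients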
capvector-pi05/scripts/train_pytorch.py
ADDED
|
@@ -0,0 +1,632 @@
"""
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.

Usage
  Single GPU:
    python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
  Example:
    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
  Multi-GPU (single node):
    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
  Example:
    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
  Multi-Node Training:
    torchrun \
        --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
        --master_addr=<master_ip> --master_port=<port> \
        scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
"""

import dataclasses
import gc
import logging
import os
import platform
import shutil
import time

import jax
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
import torch.nn.parallel
import tqdm
import wandb

import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data


def init_logging():
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    else:
        logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging."""
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)


def setup_ddp():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1
    if use_ddp and not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")

    # Set up debugging environment variables for DDP issues
    if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    return use_ddp, local_rank, device


def cleanup_ddp():
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


def set_seed(seed: int, local_rank: int):
    torch.manual_seed(seed + local_rank)
    np.random.seed(seed + local_rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed + local_rank)


def build_datasets(config: _config.TrainConfig):
    # Use the unified data loader with PyTorch framework
    data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return data_loader, data_loader.data_config()


def get_model_state_dict(model):
    """Get state dict from model, handling DDP wrapper."""
    return (
        model.module.state_dict()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.state_dict()
    )


def get_model_parameters(model):
    """Get parameters from model, handling DDP wrapper."""
    return (
        model.module.parameters()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.parameters()
    )


def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata."""
    if not is_main:
        return

    # Only save if it's time to save or if it's the final step
    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
        # Create temporary directory for atomic checkpoint saving
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Remove any existing temp directory and create a new one
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save model state using safetensors (handles shared tensors)
        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

        # Save optimizer state using PyTorch format
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Save training metadata; the config is stored as a plain dict
        # (dataclasses.asdict) to avoid pickling JAX/Flax objects
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # Save norm stats
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

        # Atomically move the temp directory to the final location
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        # Log checkpoint to wandb
        if config.wandb_enabled:
            wandb.log({"checkpoint_step": global_step}, step=global_step)


def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    latest_step = max(checkpoint_steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Clear memory before loading checkpoints
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # Load model state with error handling
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"

        if safetensors_path.exists():
            model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
            logging.info("Loaded model state from safetensors format")
        else:
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # Load optimizer state with error handling
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"

        if optimizer_path.exists():
            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
            logging.info("Loaded optimizer state from pt format")
        else:
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")

        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # Load metadata
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Clear memory and provide a detailed error message
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise


def get_latest_checkpoint_step(checkpoint_dir):
    """Get the latest checkpoint step number from a checkpoint directory."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]
    return max(checkpoint_steps) if checkpoint_steps else None


def log_memory_usage(device, step, phase="unknown"):
    """Log detailed memory usage information."""
    if not torch.cuda.is_available():
        return

    memory_allocated = torch.cuda.memory_allocated(device) / 1e9
    memory_reserved = torch.cuda.memory_reserved(device) / 1e9
    # "free" here means reserved-but-unallocated within the caching allocator
    memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9

    # Get more detailed memory info
    memory_stats = torch.cuda.memory_stats(device)
    max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
    max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9

    # Get DDP info if available
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )


def train_loop(config: _config.TrainConfig):
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # Initialize checkpoint directory and wandb
    resuming = False
    if config.resume:
        # Find checkpoint directory based on experiment name
        exp_checkpoint_dir = config.checkpoint_dir
        if exp_checkpoint_dir.exists():
            # Use validation to find the latest working checkpoint
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is not None:
                resuming = True
                logging.info(
                    f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
                )
            else:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        else:
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    # Create checkpoint directory with experiment name
    if not resuming:
        # For new runs, create an experiment-specific checkpoint directory
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        # For resume, checkpoint_dir is already set to the experiment directory
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Build data loader using the unified data loader.
    # Calculate the effective batch size per GPU for DDP:
    # for N GPUs, each GPU should get batch_size/N samples, so the total across all GPUs is batch_size.
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size  # e.g. batch_size=32 on 4 GPUs -> 8 per rank
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    # Pass the original batch size to the data loader - it handles DDP splitting internally
    loader, data_config = build_datasets(config)

    # Log sample images to wandb on the first batch
    if is_main and config.wandb_enabled and not resuming:
        # Create a separate data loader for the sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))
        # Convert observation and actions to torch tensors
        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        # Create sample images for wandb
        images_to_log = []
        # Get batch size from the first image tensor
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # Concatenate all camera views horizontally for this batch item,
            # converting from NCHW to NHWC format for wandb
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        # Clear sample batch from memory aggressively
        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader  # Also delete the sample data loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=True,  # Some parameters may not receive gradients every step
            gradient_as_bucket_view=True,  # Alias gradients to bucket views to save memory
            static_graph=world_size >= 8,  # Assume a fixed graph for 8+ GPUs
        )

    # Load weights from weight_loader if specified (for fine-tuning)
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")

        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        safetensors.torch.load_model(
            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
        )
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    start_time = time.time()
    infos = []  # Collect stats over the log interval
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            # The unified data loader returns an (observation, actions) tuple;
            # jax.tree.map is used purely as a pytree mapper to move every torch tensor to the device
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            losses = model(observation, actions)
            # Ensure losses is a tensor and handle different return types
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            loss = losses.mean()
|
| 537 |
+
|
| 538 |
+
# Backward pass
|
| 539 |
+
loss.backward()
|
| 540 |
+
|
| 541 |
+
# Log memory usage after backward pass
|
| 542 |
+
if global_step < 5 and is_main and torch.cuda.is_available():
|
| 543 |
+
log_memory_usage(device, global_step, "after_backward")
|
| 544 |
+
|
| 545 |
+
# Gradient clipping
|
| 546 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
|
| 547 |
+
|
| 548 |
+
# Optimizer step
|
| 549 |
+
optim.step()
|
| 550 |
+
optim.zero_grad(set_to_none=True)
|
| 551 |
+
|
| 552 |
+
# Clear gradients more aggressively
|
| 553 |
+
for param in model.parameters():
|
| 554 |
+
if param.grad is not None:
|
| 555 |
+
param.grad.detach_()
|
| 556 |
+
param.grad = None
|
| 557 |
+
|
| 558 |
+
# Collect stats
|
| 559 |
+
if is_main:
|
| 560 |
+
infos.append(
|
| 561 |
+
{
|
| 562 |
+
"loss": loss.item(),
|
| 563 |
+
"learning_rate": optim.param_groups[0]["lr"],
|
| 564 |
+
"grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
|
| 565 |
+
}
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
if is_main and (global_step % config.log_interval == 0):
|
| 569 |
+
elapsed = time.time() - start_time
|
| 570 |
+
|
| 571 |
+
# Average stats over log interval
|
| 572 |
+
avg_loss = sum(info["loss"] for info in infos) / len(infos)
|
| 573 |
+
avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
|
| 574 |
+
|
| 575 |
+
avg_grad_norm = None
|
| 576 |
+
if any("grad_norm" in info for info in infos):
|
| 577 |
+
vals = [
|
| 578 |
+
info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
|
| 579 |
+
]
|
| 580 |
+
if len(vals) > 0:
|
| 581 |
+
avg_grad_norm = sum(vals) / len(vals)
|
| 582 |
+
logging.info(
|
| 583 |
+
f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
|
| 584 |
+
if avg_grad_norm is not None
|
| 585 |
+
else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# Log to wandb
|
| 589 |
+
if config.wandb_enabled and len(infos) > 0:
|
| 590 |
+
log_payload = {
|
| 591 |
+
"loss": avg_loss,
|
| 592 |
+
"learning_rate": avg_lr,
|
| 593 |
+
"step": global_step,
|
| 594 |
+
"time_per_step": elapsed / config.log_interval,
|
| 595 |
+
}
|
| 596 |
+
if avg_grad_norm is not None:
|
| 597 |
+
log_payload["grad_norm"] = avg_grad_norm
|
| 598 |
+
wandb.log(log_payload, step=global_step)
|
| 599 |
+
|
| 600 |
+
start_time = time.time()
|
| 601 |
+
infos = [] # Reset stats collection
|
| 602 |
+
|
| 603 |
+
global_step += 1
|
| 604 |
+
# Save checkpoint using the new mechanism
|
| 605 |
+
save_checkpoint(model, optim, global_step, config, is_main, data_config)
|
| 606 |
+
|
| 607 |
+
# Update progress bar
|
| 608 |
+
if pbar is not None:
|
| 609 |
+
pbar.update(1)
|
| 610 |
+
pbar.set_postfix(
|
| 611 |
+
{"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
# Close progress bar
|
| 615 |
+
if pbar is not None:
|
| 616 |
+
pbar.close()
|
| 617 |
+
|
| 618 |
+
# Finish wandb run
|
| 619 |
+
if is_main and config.wandb_enabled:
|
| 620 |
+
wandb.finish()
|
| 621 |
+
|
| 622 |
+
cleanup_ddp()
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def main():
|
| 626 |
+
init_logging()
|
| 627 |
+
config = _config.cli()
|
| 628 |
+
train_loop(config)
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
if __name__ == "__main__":
|
| 632 |
+
main()
|
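The warmup-plus-cosine schedule above is the main numerical subtlety in this trainer, so here is a minimal standalone check of its shape. This is a sketch with assumed example hyperparameters (warmup_steps, peak_lr, decay_steps, end_lr below are illustrative, not values from any shipped config):

import numpy as np

warmup_steps, peak_lr, decay_steps, end_lr = 1_000, 2.5e-5, 30_000, 2.5e-6

def lr_schedule(step: int) -> float:
    if step < warmup_steps:
        init_lr = peak_lr / (warmup_steps + 1)  # near-zero start, matching the JAX trainer
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    return end_lr + (peak_lr - end_lr) * 0.5 * (1 + np.cos(np.pi * progress))

print(lr_schedule(0))             # ~2.5e-8: peak_lr / (warmup_steps + 1)
print(lr_schedule(warmup_steps))  # 2.5e-5: exactly peak_lr at the end of warmup
print(lr_schedule(decay_steps))   # 2.5e-6: cosine fully decayed to end_lr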
capvector-pi05/scripts/train_regular_loss_pytorch.py ADDED
@@ -0,0 +1,754 @@
"""
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.

Usage
Single GPU:
    python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
Example:
    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
Multi-GPU (single node):
    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
Example:
    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
Multi-Node Training:
    torchrun \
        --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
        --master_addr=<master_ip> --master_port=<port> \
        scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
"""

import dataclasses
import gc
import logging
import os
import platform
from pathlib import Path
import shutil
import time

import jax
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
import torch.nn.parallel
import tqdm
import wandb

import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data


def init_logging():
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    else:
        logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging."""
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.name,
            config=dataclasses.asdict(config),
            project=config.project_name,
            id="-".join([config.name, config.exp_name]),
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)


def setup_ddp():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1
    if use_ddp and not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")

        # Set up debugging environment variables for DDP issues
        if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    return use_ddp, local_rank, device


def cleanup_ddp():
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


def set_seed(seed: int, local_rank: int):
    torch.manual_seed(seed + local_rank)
    np.random.seed(seed + local_rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed + local_rank)


def build_datasets(config: _config.TrainConfig):
    # Use the unified data loader with PyTorch framework
    data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return data_loader, data_loader.data_config()


def get_model_state_dict(model):
    """Get state dict from model, handling DDP wrapper."""
    return (
        model.module.state_dict()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.state_dict()
    )


def get_model_parameters(model):
    """Get parameters from model, handling DDP wrapper."""
    return (
        model.module.parameters()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.parameters()
    )


def load_regular_vector_dict(path: str | Path) -> dict[str, torch.Tensor]:
    """Load the regularization vectors, which are used for delta-based regularization."""
    tensor_path = Path(path)
    suffix = tensor_path.suffix.lower()

    if suffix in {".pt", ".pth"}:
        tensors = torch.load(tensor_path, map_location="cpu", weights_only=False, mmap=True)
    elif suffix == ".safetensors":
        tensors = safetensors.torch.load_file(str(tensor_path), device="cpu")
    else:
        raise ValueError(f"Unsupported tensor file format: {tensor_path}")

    return tensors["state_dict"]


def prepare_regularization_context(
    model,
    config: _config.TrainConfig,
) -> dict | None:
    """Load regularization tensors and build the runtime context for delta-based regularization."""

    # Regularization is optional: skip it when no vector path or a zero coefficient is configured
    if not config.regularization_vector_path or config.regularization_coeff == 0.0:
        return None

    # Get the regularization vectors as reference directions
    if config.resume:
        raise ValueError(
            "Delta-based regularization with --resume is not supported in this PyTorch trainer. "
            "This run now keeps the anchor only in memory at startup."
        )
    vector_path = Path(config.regularization_vector_path).expanduser()
    if not vector_path.exists():
        raise FileNotFoundError(f"Regularization vector file does not exist: {vector_path}")
    regularization_vectors = load_regular_vector_dict(vector_path)

    # Get the model's trainable parameters to be regularized and the corresponding freezing anchors at startup
    model_module = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model

    trainable_entries = []
    missing_vectors = 0
    shape_mismatches = 0
    trainable_param_names = set()

    for name, param in model_module.named_parameters():
        if not param.requires_grad:
            continue
        trainable_param_names.add(name)
        regularization_vector = regularization_vectors.get(name)
        if regularization_vector is None:
            missing_vectors += 1
            continue
        anchor_param = param.detach().clone().contiguous()
        if regularization_vector.shape != param.shape or anchor_param.shape != param.shape:
            shape_mismatches += 1
            continue
        trainable_entries.append(
            {
                "name": name,
                "param": param,
                "anchor": anchor_param,
                "vector": regularization_vector.to(device=param.device, dtype=param.dtype).contiguous(),
            }
        )

    logging.info(
        "Regularization coverage: matched=%d missing_vectors=%d shape_mismatches=%d",
        len(trainable_entries),
        missing_vectors,
        shape_mismatches,
    )

    return {
        "entries": trainable_entries,
        "weight": config.regularization_coeff,
        "vector_path": str(vector_path),
    }


def compute_regularization_loss(regularization_context: dict | None, device: torch.device) -> torch.Tensor:
    """Compute the delta-based regularization loss for the current model parameters."""
    reg_loss = torch.zeros((), device=device, dtype=torch.float32)

    if not regularization_context:
        return reg_loss

    for entry in regularization_context["entries"]:
        param = entry["param"]
        anchor = entry["anchor"]
        vector = entry["vector"]

        delta = (param - anchor).reshape(-1).float()
        direction = vector.reshape(-1).float()
        reg_loss = reg_loss + torch.abs(torch.dot(delta, direction))

    return reg_loss * regularization_context["weight"]


def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata."""
    if not is_main:
        return

    # Only save if it's time to save or if it's the final step
    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
        # Create temporary directory for atomic checkpoint saving
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Remove any existing temp directory and create new one
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save model state using safetensors (handle shared tensors)
        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

        # Save optimizer state using PyTorch format
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues)
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # save norm stats
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

        # Atomically move temp directory to final location
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        # Log checkpoint to wandb
        if config.wandb_enabled:
            wandb.log({"checkpoint_step": global_step}, step=global_step)


def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    latest_step = max(checkpoint_steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Clear memory before loading checkpoints
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # Load model state with error handling
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"

        if safetensors_path.exists():
            model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
            logging.info("Loaded model state from safetensors format")
        else:
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # Load optimizer state with error handling
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"

        if optimizer_path.exists():
            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
            logging.info("Loaded optimizer state from pt format")
        else:
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")

        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # Load metadata
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Clear memory and provide detailed error message
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise


def get_latest_checkpoint_step(checkpoint_dir):
    """Get the latest checkpoint step number from a checkpoint directory."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]
    return max(checkpoint_steps) if checkpoint_steps else None


def log_memory_usage(device, step, phase="unknown"):
    """Log detailed memory usage information."""
    if not torch.cuda.is_available():
        return

    memory_allocated = torch.cuda.memory_allocated(device) / 1e9
    memory_reserved = torch.cuda.memory_reserved(device) / 1e9
    memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)
    memory_free = memory_free / 1e9

    # Get more detailed memory info
    memory_stats = torch.cuda.memory_stats(device)
    max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
    max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9

    # Get DDP info if available
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )


def train_loop(config: _config.TrainConfig):
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # Initialize checkpoint directory and wandb
    resuming = False
    if config.resume:
        # Find checkpoint directory based on experiment name
        exp_checkpoint_dir = config.checkpoint_dir
        if exp_checkpoint_dir.exists():
            # Use validation to find the latest working checkpoint
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is not None:
                resuming = True
                logging.info(
                    f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
                )
            else:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        else:
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    # Create checkpoint directory with experiment name
    if not resuming:
        # For new runs, create experiment-specific checkpoint directory
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        # For resume, checkpoint_dir is already set to the experiment directory
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Build data loader using the unified data loader
    # Calculate effective batch size per GPU for DDP
    # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    # Pass the original batch size to data loader - it will handle DDP splitting internally
    loader, data_config = build_datasets(config)

    # Log sample images to wandb on first batch
    if is_main and config.wandb_enabled and not resuming:
        # Create a separate data loader for sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))
        # Convert observation and actions to torch tensors
        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        # Create sample images for wandb
        images_to_log = []
        # Get batch size from the first image tensor
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # Concatenate all camera views horizontally for this batch item
            # Convert from NCHW to NHWC format for wandb
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        # Clear sample batch from memory aggressively
        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader  # Also delete the sample data loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=True,  # Could be set to False for memory efficiency if all parameters receive gradients
            gradient_as_bucket_view=True,  # Enable for memory efficiency
            static_graph=world_size >= 8,  # Enable for 8+ GPUs
        )

    # Load weights from weight_loader if specified (for fine-tuning)
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")

        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        safetensors.torch.load_model(
            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
        )
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    regularization_context = prepare_regularization_context(model, config)

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    start_time = time.time()
    infos = []  # Collect stats over log interval
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")
        if regularization_context:
            logging.info(
                "Delta-based regularization: enabled | weight=%.2e | vector=%s",
                config.regularization_coeff,
                regularization_context["vector_path"],
            )

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            # The unified data loader returns (observation, actions) tuple
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            losses = model(observation, actions)
            # Ensure losses is a tensor and handle different return types
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            action_loss = losses.mean()
            regularization_loss = compute_regularization_loss(regularization_context, device)
            total_loss = action_loss + regularization_loss

            # Backward pass
            total_loss.backward()

            # Log memory usage after backward pass
            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)

            # Optimizer step
            optim.step()
            optim.zero_grad(set_to_none=True)

            # Clear gradients more aggressively
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad = None

            # Collect stats
            if is_main:
                infos.append(
                    {
                        "action_loss": action_loss.item(),
                        "regularization_loss": regularization_loss.item(),
                        "total_loss": total_loss.item(),
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time

                # Average stats over log interval
                avg_action_loss = sum(info["action_loss"] for info in infos) / len(infos)
                avg_regularization_loss = sum(info["regularization_loss"] for info in infos) / len(infos)
                avg_total_loss = sum(info["total_loss"] for info in infos) / len(infos)
                avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)

                avg_grad_norm = None
                if any("grad_norm" in info for info in infos):
                    vals = [
                        info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
                    ]
                    if len(vals) > 0:
                        avg_grad_norm = sum(vals) / len(vals)
                logging.info(
                    f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
                    if avg_grad_norm is not None
                    else f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
                )

                # Log to wandb
                if config.wandb_enabled and len(infos) > 0:
                    log_payload = {
                        "action_loss": avg_action_loss,
                        "regularization_loss": avg_regularization_loss,
                        "total_loss": avg_total_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        "time_per_step": elapsed / config.log_interval,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []  # Reset stats collection

            global_step += 1
            # Save checkpoint using the new mechanism
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            # Update progress bar
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {
                        "action_loss": f"{action_loss.item():.4f}",
                        "reg_loss": f"{regularization_loss.item():.4f}",
                        "total_loss": f"{total_loss.item():.4f}",
                        "lr": f"{optim.param_groups[0]['lr']:.2e}",
                        "step": global_step,
                    }
                )

    # Close progress bar
    if pbar is not None:
        pbar.close()

    # Finish wandb run
    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()


def main():
    init_logging()
    config = _config.cli()
    train_loop(config)


if __name__ == "__main__":
    main()
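To make the delta-based penalty concrete, here is a standalone numeric sketch of the same computation `compute_regularization_loss` performs for a single anchored parameter: it penalizes |<theta - theta_0, v>|, the absolute projection of the weight update onto a fixed reference direction. The tensors and coefficient below are toy values for illustration, not real model weights or a shipped regularization vector:

import torch

weight = 1e-3  # plays the role of config.regularization_coeff (assumed example value)
param = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
anchor = param.detach().clone()            # frozen copy captured at startup
direction = torch.tensor([0.0, 1.0, 0.0])  # reference direction for this parameter

with torch.no_grad():
    param += torch.tensor([0.5, -0.25, 0.0])  # pretend one optimizer step happened

delta = (param - anchor).reshape(-1).float()
reg_loss = weight * torch.abs(torch.dot(delta, direction.reshape(-1).float()))
print(reg_loss.item())  # 1e-3 * |<delta, direction>| = 1e-3 * 0.25 = 2.5e-4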
capvector-pi05/scripts/train_test.py ADDED
@@ -0,0 +1,30 @@
import dataclasses
import os
import pathlib

import pytest

os.environ["JAX_PLATFORMS"] = "cpu"

from openpi.training import config as _config

from . import train


@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    config = dataclasses.replace(
        _config._CONFIGS_DICT[config_name],  # noqa: SLF001
        batch_size=2,
        checkpoint_base_dir=str(tmp_path / "checkpoint"),
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(config)

    # test resuming
    config = dataclasses.replace(config, resume=True, num_train_steps=4)
    train.main(config)
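As a hedged usage note, this smoke test can also be driven programmatically rather than from the pytest CLI; the sketch below assumes pytest is installed and the repository root is the working directory:

import os

import pytest

os.environ["JAX_PLATFORMS"] = "cpu"  # mirror the CPU-only setup the test module applies
raise SystemExit(pytest.main(["-q", "capvector-pi05/scripts/train_test.py"]))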
capvector-pi05/src/openpi/__init__.py ADDED
File without changes
capvector-pi05/src/openpi/conftest.py ADDED
@@ -0,0 +1,17 @@
import os

import pynvml
import pytest


def set_jax_cpu_backend_if_no_gpu() -> None:
    try:
        pynvml.nvmlInit()
        pynvml.nvmlShutdown()
    except pynvml.NVMLError:
        # No GPU found.
        os.environ["JAX_PLATFORMS"] = "cpu"


def pytest_configure(config: pytest.Config) -> None:
    set_jax_cpu_backend_if_no_gpu()
capvector-pi05/src/openpi/models/__init__.py ADDED
File without changes
capvector-pi05/src/openpi/models/gemma.py ADDED
@@ -0,0 +1,459 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma adaptation for Pi, taken from big_vision.

We follow this einsum axis naming convention:
  B: batch
  T: query length
  S: k/v length
  N: num query heads
  K: num k/v heads
  G: num query heads per k/v head
  H: head dim
  D: d_model ("features")
"""

from collections.abc import Sequence
import dataclasses
from typing import Literal, TypeAlias

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp

import openpi.models.lora as lora
import openpi.shared.array_typing as at
import openpi.training.sharding as sharding

PALIGEMMA_VOCAB_SIZE = 257_152


@dataclasses.dataclass
class Config:
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)


Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]


def get_config(variant: Variant) -> Config:
    """Returns config for specified gemma variant."""
    if variant == "dummy":
        return Config(
            width=64,
            depth=4,
            mlp_dim=128,
            num_heads=8,
            num_kv_heads=1,
            head_dim=16,
        )
    if variant == "gemma_300m":
        # 311M params
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b_lora":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
        )
    if variant == "gemma_300m_lora":
        # 311M params
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
        )
    raise ValueError(f"Unknown variant: {variant}")


@at.typecheck
class RMSNorm(nn.Module):
    @nn.compact
    def __call__(self, x, cond):
        dtype = x.dtype  # original dtype, could be half-precision
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32
        if cond is None:
            # regular RMSNorm
            scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
            normed_inputs = normed_inputs * (
                1 + scale
            )  # scale by learned parameter in float32 (matches Flax implementation)
            return normed_inputs.astype(dtype), None  # return in original dtype

        # adaptive RMSNorm
        modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
        scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
        normed_inputs = normed_inputs * (1 + scale) + shift  # scale and shift in float32
        return normed_inputs.astype(dtype), gate


@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        return jnp.dot(x, self.input_embedding_table.T)


@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    configs: Sequence[Config]

    @nn.compact
    def __call__(self, xs, positions, attn_mask, kv_cache):
        # all experts must share the same head dim, num heads, and num kv heads for self-attention to work
        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)

        dtype = next(x.dtype for x in xs if x is not None)  # original dtype, could be half-precision

        qkvs = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is None:
                continue
            if config.num_kv_heads == config.num_heads:
                qkv_einsum = lora.Einsum(
                    shape=(3, config.num_heads, config.width, config.head_dim),
                    name=_name("qkv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
            else:
                q_einsum = lora.Einsum(
                    shape=(config.num_heads, config.width, config.head_dim),
                    name=_name("q_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                    lora_config=config.lora_configs.get("attn"),
                )
                q = q_einsum("BTD,NDH->BTNH", x)
                kv_einsum = lora.Einsum(
                    shape=(2, config.num_kv_heads, config.width, config.head_dim),
                    name=_name("kv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                k, v = kv_einsum("BSD,2KDH->2BSKH", x)
                qkvs.append((q, k, v))

        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))

        q = _apply_rope(q, positions=positions)
        q *= self.configs[0].head_dim ** -0.5

        k = _apply_rope(k, positions=positions)

        # should still be half-precision here (if input was half-precision)
        assert q.dtype == k.dtype == v.dtype == dtype

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            k = jnp.concatenate([cache_k, k], axis=1)
            v = jnp.concatenate([cache_v, v], axis=1)

        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")

        out = []
        start = 0
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                end = start + x.shape[1]
                out_einsum = lora.Einsum(
                    shape=(config.num_heads, config.head_dim, config.width),
                    name=_name("attn_vec_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                    lora_config=config.lora_configs.get("attn"),
                )
                out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
                start = end
            else:
                out.append(None)

        return out, (k, v)


@at.typecheck
class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        ).astype(dtype)
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        ).astype(dtype)
        outputs = jnp.dot(activations, w_linear)
        assert outputs.dtype == dtype
        return outputs


@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    configs: tuple[Config, ...]

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True):  # noqa: FBT002
        xs = sharding.activation_sharding_constraint(xs)
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        pre_attn = []
|
| 300 |
+
gates = []
|
| 301 |
+
for i, x in enumerate(xs):
|
| 302 |
+
if x is not None:
|
| 303 |
+
x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
|
| 304 |
+
pre_attn.append(x)
|
| 305 |
+
gates.append(gate if x is not None else None)
|
| 306 |
+
|
| 307 |
+
pre_attn = sharding.activation_sharding_constraint(pre_attn)
|
| 308 |
+
post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
|
| 309 |
+
post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
|
| 310 |
+
post_attn = sharding.activation_sharding_constraint(post_attn)
|
| 311 |
+
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
|
| 312 |
+
xs = sharding.activation_sharding_constraint(xs)
|
| 313 |
+
|
| 314 |
+
out = []
|
| 315 |
+
gates = []
|
| 316 |
+
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
| 317 |
+
if x is not None:
|
| 318 |
+
x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
|
| 319 |
+
x = lora.FeedForward( # noqa: PLW2901
|
| 320 |
+
features=config.width,
|
| 321 |
+
hidden_dim=config.mlp_dim,
|
| 322 |
+
name=_name("mlp", i),
|
| 323 |
+
lora_config=config.lora_configs.get("ffn"),
|
| 324 |
+
)(x)
|
| 325 |
+
out.append(x)
|
| 326 |
+
gates.append(gate if x is not None else None)
|
| 327 |
+
|
| 328 |
+
out = sharding.activation_sharding_constraint(out)
|
| 329 |
+
out = jax.tree.map(lambda x: drop(x, deterministic), out)
|
| 330 |
+
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
|
| 331 |
+
xs = sharding.activation_sharding_constraint(xs)
|
| 332 |
+
|
| 333 |
+
return xs, kv_cache
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@at.typecheck
|
| 340 |
+
class Module(nn.Module):
|
| 341 |
+
"""Transformer model, supporting a mixture of different weights for different tokens."""
|
| 342 |
+
|
| 343 |
+
configs: Sequence[Config] # list of configs, one for each expert
|
| 344 |
+
embed_dtype: str
|
| 345 |
+
|
| 346 |
+
dropout: float = 0.0
|
| 347 |
+
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
| 348 |
+
adarms: bool = False
|
| 349 |
+
|
| 350 |
+
def setup(self):
|
| 351 |
+
# all experts must have the same depth
|
| 352 |
+
assert all(config.depth == self.configs[0].depth for config in self.configs)
|
| 353 |
+
|
| 354 |
+
self.embedder = Embedder(
|
| 355 |
+
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
| 356 |
+
embed_dim=self.configs[0].width, # embedder for first expert only
|
| 357 |
+
name="embedder",
|
| 358 |
+
)
|
| 359 |
+
block_cls = nn.remat(
|
| 360 |
+
Block,
|
| 361 |
+
prevent_cse=False,
|
| 362 |
+
static_argnums=(5,), # 0=self, 6=deterministic
|
| 363 |
+
policy=jax.checkpoint_policies.nothing_saveable,
|
| 364 |
+
)
|
| 365 |
+
self.layers = nn.scan(
|
| 366 |
+
block_cls,
|
| 367 |
+
variable_axes={"params": 0},
|
| 368 |
+
split_rngs={"params": True, "dropout": True},
|
| 369 |
+
in_axes=(
|
| 370 |
+
0,
|
| 371 |
+
nn.broadcast,
|
| 372 |
+
nn.broadcast,
|
| 373 |
+
nn.broadcast,
|
| 374 |
+
nn.broadcast,
|
| 375 |
+
), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
|
| 376 |
+
length=self.configs[0].depth,
|
| 377 |
+
)(
|
| 378 |
+
configs=self.configs,
|
| 379 |
+
dropout=self.dropout,
|
| 380 |
+
dropout_bdims=self.dropout_bdims,
|
| 381 |
+
)
|
| 382 |
+
self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
|
| 383 |
+
|
| 384 |
+
@at.typecheck
|
| 385 |
+
def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
|
| 386 |
+
return self.embedder.encode(tokens).astype(self.embed_dtype)
|
| 387 |
+
|
| 388 |
+
@at.typecheck
|
| 389 |
+
def __call__(
|
| 390 |
+
self,
|
| 391 |
+
# list of token arrays, one for each expert, or None if that expert should not be run
|
| 392 |
+
embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
|
| 393 |
+
positions: at.Int[at.Array, "b t"],
|
| 394 |
+
mask: at.Bool[at.Array, "b t s"],
|
| 395 |
+
adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
|
| 396 |
+
*,
|
| 397 |
+
kv_cache: KVCache | None = None,
|
| 398 |
+
deterministic: bool = True,
|
| 399 |
+
) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
|
| 400 |
+
embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
|
| 401 |
+
mask = jnp.asarray(mask)[:, None, :, :]
|
| 402 |
+
if adarms_cond is None:
|
| 403 |
+
adarms_cond = [None] * len(self.configs)
|
| 404 |
+
|
| 405 |
+
embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
|
| 406 |
+
|
| 407 |
+
assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
|
| 408 |
+
|
| 409 |
+
return [
|
| 410 |
+
f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
|
| 411 |
+
], kv_cache
|
| 412 |
+
|
| 413 |
+
def init(self, use_adarms: Sequence[bool]):
|
| 414 |
+
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
|
| 415 |
+
self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
|
| 416 |
+
self(
|
| 417 |
+
[jnp.zeros((1, 1, c.width)) for c in self.configs],
|
| 418 |
+
jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
|
| 419 |
+
jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
|
| 420 |
+
adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
| 425 |
+
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
| 426 |
+
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
| 427 |
+
timescale = max_wavelength**freq_exponents
|
| 428 |
+
radians = positions[..., None] / timescale[None, None, :]
|
| 429 |
+
radians = radians[..., None, :]
|
| 430 |
+
assert radians.dtype == jnp.float32
|
| 431 |
+
# radians.shape = [...,L,1,d=D/2]
|
| 432 |
+
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
| 433 |
+
x1, x2 = jnp.split(x, 2, axis=-1)
|
| 434 |
+
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
| 435 |
+
assert res.dtype == jnp.float32
|
| 436 |
+
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
|
| 437 |
+
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
|
| 438 |
+
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
|
| 439 |
+
# here.
|
| 440 |
+
return res.astype(x.dtype)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _name(name, i):
|
| 444 |
+
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
|
| 445 |
+
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
|
| 446 |
+
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
|
| 447 |
+
# and the action expert.
|
| 448 |
+
if i == 0:
|
| 449 |
+
return name
|
| 450 |
+
return f"{name}_{i}"
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def _gated_residual(x, y, gate):
|
| 454 |
+
assert (x is None) == (y is None)
|
| 455 |
+
if x is None:
|
| 456 |
+
return None
|
| 457 |
+
if gate is None:
|
| 458 |
+
return x + y
|
| 459 |
+
return x + y * gate
|
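Note: a minimal standalone sketch of the adaptive-RMSNorm ("adaRMS") modulation implemented at the top of this file. The function and parameter names here (`adaptive_rmsnorm`, `w`, `b`, `cond`) are illustrative placeholders, not part of the module API above:

import jax.numpy as jnp

def adaptive_rmsnorm(x, cond, w, b):
    # x: [B, T, D] activations; cond: [B, D_c] conditioning vector.
    # w: [D_c, 3*D], b: [3*D] -- standing in for the zero-initialized nn.Dense above.
    var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
    normed = x * jnp.reciprocal(jnp.sqrt(var + 1e-6))  # RMS-normalize in float32
    modulation = cond @ w + b                          # [B, 3*D]
    scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)  # each [B, 1, D]
    return normed * (1 + scale) + shift, gate          # gate later multiplies the residual branch

Because the projection is zero-initialized, scale, shift, and gate all start at zero, so each block begins as a plain RMSNorm with a disabled (zero-gated) residual branch.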
capvector-pi05/src/openpi/models/gemma_fast.py
ADDED
@@ -0,0 +1,437 @@

# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility).
Used for FAST autoregressive policies.
"""

import dataclasses
from typing import Literal, TypeAlias

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections

import openpi.models.lora as lora
import openpi.shared.array_typing as at

Variant = Literal["gemma_2b", "gemma_2b_lora"]


def get_config(variant):
    """Returns config for specified gemma variant."""
    if variant == "gemma_2b":
        return ml_collections.ConfigDict(
            {
                "variant": variant,
                "width": 2048,
                "depth": 18,
                "mlp_dim": 16_384,
                "num_heads": 8,
                "num_kv_heads": 1,
                "head_dim": 256,
                "norm_eps": 1e-6,
                "vocab_size": 257_152,
                "scan": True,
                "remat_policy": "nothing_saveable",
            }
        )
    if variant == "gemma_2b_lora":
        return ml_collections.ConfigDict(
            {
                "variant": variant,
                "width": 2048,
                "depth": 18,
                "mlp_dim": 16_384,
                "num_heads": 8,
                "num_kv_heads": 1,
                "head_dim": 256,
                "norm_eps": 1e-6,
                "vocab_size": 257_152,
                "scan": True,
                "remat_policy": "nothing_saveable",
                "lora_configs": {
                    "attn": lora.LoRAConfig(rank=16, alpha=16.0),
                    "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
                },
            }
        )
    raise ValueError(f"Unknown variant: {variant}")


@at.typecheck
class Einsum(nn.Module):
    shape: tuple[int, ...]

    @nn.compact
    def __call__(self, eqn, x):
        dtype = x.dtype  # original dtype, could be half-precision
        w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
        return jnp.einsum(eqn, x, w)


@at.typecheck
class RMSNorm(nn.Module):
    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32
        normed_inputs = normed_inputs * (
            1 + scale
        )  # scale by learned parameter in float32 (matches Flax implementation)
        return normed_inputs.astype(dtype)  # return in original dtype


@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.zeros_init(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        return jnp.dot(x, self.input_embedding_table.T)


@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    num_heads: int
    num_kv_heads: int
    features: int
    head_dim: int

    cache_dtype: str | None = None

    lora_config: lora.LoRAConfig | None = None

    def setup(self):
        if self.num_kv_heads == self.num_heads:
            self.qkv_einsum = lora.Einsum(
                shape=(3, self.num_heads, self.features, self.head_dim),
                name="qkv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        else:
            self.q_einsum = lora.Einsum(
                shape=(self.num_heads, self.features, self.head_dim),
                name="q_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                lora_config=self.lora_config,
            )
            self.kv_einsum = lora.Einsum(
                shape=(2, self.num_kv_heads, self.features, self.head_dim),
                name="kv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        self.attn_vec_einsum = lora.Einsum(
            shape=(self.num_heads, self.head_dim, self.features),
            name="attn_vec_einsum",
            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            lora_config=self.lora_config,
        )

    def _init_cache(self, k, v, cache_size):
        """Initialize KV cache."""
        prefill_len = k.shape[1]
        pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
        cache_dtype = self.cache_dtype or k.dtype
        k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
        v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
        idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
        return idx, k_cache, v_cache

    def _update_cache(self, k, v, idx, k_cache, v_cache):
        """Update KV cache with new values."""
        assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
        indices = (0, idx[0], 0, 0)
        cache_dtype = self.cache_dtype or k.dtype
        k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
        v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
        idx_new = idx + 1
        return idx_new, k_new, v_new

    @nn.compact
    def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True):  # noqa: FBT002
        dtype = x.dtype  # original dtype, could be half-precision
        if self.num_kv_heads == self.num_heads:
            q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
        else:
            q = self.q_einsum("BTD,NDH->BTNH", x)
            k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)

        q = _apply_rope(q, positions=positions)  # promotes to float32
        q *= self.head_dim**-0.5

        k = _apply_rope(k, positions=positions)  # promotes to float32

        if kv_cache is None:
            idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
        else:
            idx, k_cache, v_cache = kv_cache
            idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)

        k, v = k_cache, v_cache
        kv_cache = (idx, k_cache, v_cache)

        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
        return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache


@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    num_heads: int
    num_kv_heads: int
    embed_dim: int
    head_dim: int
    hidden_dim: int

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    cache_dtype: str | None = None
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    def setup(self):
        self.pre_attention_norm = RMSNorm()
        self.attn = Attention(
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            features=self.embed_dim,
            head_dim=self.head_dim,
            cache_dtype=self.cache_dtype,
            lora_config=self.lora_configs.get("attn"),
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = lora.FeedForward(
            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
        )
        if self.dropout:
            self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
        else:
            self.drop = lambda x, _: x

    def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True):  # noqa: FBT002
        x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
        inputs_normalized = self.pre_attention_norm(x)
        attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
        attn_output = self.drop(attn_output, deterministic)
        attn_output += x
        residual = attn_output
        attn_output = self.pre_ffw_norm(attn_output)
        outputs = self.mlp(attn_output)
        outputs = self.drop(outputs, deterministic)
        outputs = residual + outputs
        return outputs, kv_cache


KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]


@at.typecheck
class Module(nn.Module):
    """Gemma model."""

    variant: str

    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    norm_eps: float
    vocab_size: int
    embed_dtype: str

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.
    cache_dtype: str | None = None

    scan: bool = False
    remat_policy: str = "none"
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    @nn.compact
    def __call__(
        self,
        tokens=None,
        embedded_prefix=None,
        embed_only=False,  # noqa: FBT002
        pre_logits=None,
        positions=None,
        mask=None,
        decode=False,  # noqa: FBT002
        kv_cache=None,
        deterministic=True,  # noqa: FBT002
        return_prelogits=False,  # noqa: FBT002
    ):
        """Embed only, or complete forward pass.

        Args:
          tokens: Embedded and then appended to `embedded_prefix`. Can be None.
          embedded_prefix: Optional prefix that is already embedded.
          embed_only: Whether to compute embeddings only.
          pre_logits: If present, computes logits from pre_logits and returns them.
          positions: Optional `[B, T]` allows to specify the absolute position of
            the tokens.
          mask: Optional attention mask `[B, T, S]`.
          decode: Whether to use kv-cache. Caller must pass masks and positions.
          deterministic: Forwarded to all dropout layers.
          return_prelogits: Whether to return the pre-logits.

        Returns:
          If `embed_only=False`, then `(logits, out)` will be returned.
          If `embed_only=True`, then the embeddings will be returned.
          If `return_prelogits=True`, then the pre-logits will be returned.
        """
        out = {}

        embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")

        if pre_logits is not None:
            x = out["pre_logits"] = pre_logits
            logits = out["logits"] = embedder.decode(x)
            return logits, out

        x = []
        if embedded_prefix is not None:
            x.append(embedded_prefix)
        if tokens is not None:
            x.append(embedder.encode(tokens))

        x = jnp.concatenate(x, axis=-2)
        x = x.astype(self.embed_dtype)
        batch_size, seq_len, width = x.shape

        if embed_only:
            return x

        if decode:
            assert positions is not None and mask is not None, (  # noqa: PT018
                "Must explicitly pass positions and mask for decoding."
            )

        if positions is None:
            positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
        assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)

        if mask is None:
            mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
        if mask.ndim == 3:
            mask = mask[:, None, :, :]
        cache_size = max(seq_len, mask.shape[-1])
        assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape

        if self.remat_policy == "none":
            block_cls = Block
        else:
            block_cls = nn.remat(
                Block,
                prevent_cse=not self.scan,
                static_argnums=(5, 6),  # 0=self, 5=decode, 6=deterministic
                policy=getattr(jax.checkpoint_policies, self.remat_policy),
            )

        block_kw = {
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "num_kv_heads": self.num_kv_heads,
            "embed_dim": width,
            "hidden_dim": self.mlp_dim,
            "dropout": self.dropout,
            "dropout_bdims": self.dropout_bdims,
            "cache_dtype": self.cache_dtype,
            "lora_configs": self.lora_configs,
        }
        layers = self.scope.push("layers")
        blocks = [
            nn.scan(
                block_cls,
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),  # 0=kv_cache, 1=positions, 2=mask
                length=self.depth,
            )(parent=layers, **block_kw)
        ]
        for block in blocks:
            x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)

        assert x.dtype == jnp.dtype(self.embed_dtype)  # Sanity check.
        out["encoded"] = x

        x = RMSNorm(name="final_norm")(x)
        out["pre_logits"] = x
        if return_prelogits:
            return x, kv_cache, out

        x = embedder.decode(x)
        out["logits"] = x

        return x, kv_cache, out

    def init(self):
        """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
        self(jnp.zeros((1, 1), dtype=jnp.int32))


def _apply_rope(x, *, positions, max_wavelength=10_000):
    """Applies RoPE positions [B, L] to x [B, L, H, D]."""
    freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None] / timescale[None, None, :]
    radians = radians[..., None, :]
    assert radians.dtype == jnp.float32
    # radians.shape = [...,L,1,d=D/2]
    sin, cos = jnp.sin(radians), jnp.cos(radians)
    x1, x2 = jnp.split(x, 2, axis=-1)
    res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    assert res.dtype == jnp.float32
    return res
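Note: the `_init_cache`/`_update_cache` pair above implements a fixed-size prefill-then-decode KV cache. A minimal sketch of the same pattern on plain `[B, S, K, H]` arrays; the function names, shapes, and toy values below are assumptions for illustration, not the module API:

import jax
import jax.numpy as jnp

def prefill(k, v, cache_size):
    # Pad the prefix keys/values out to the full cache length; idx records the fill level.
    pad = ((0, 0), (0, cache_size - k.shape[1]), (0, 0), (0, 0))
    idx = jnp.full((k.shape[0],), k.shape[1], dtype=jnp.int32)
    return idx, jnp.pad(k, pad), jnp.pad(v, pad)

def decode_step(k_new, v_new, cache):
    # Write one new timestep into the cache at position idx (same offset for all batch rows).
    idx, k_cache, v_cache = cache
    start = (0, idx[0], 0, 0)
    k_cache = jax.lax.dynamic_update_slice(k_cache, k_new, start)
    v_cache = jax.lax.dynamic_update_slice(v_cache, v_new, start)
    return idx + 1, k_cache, v_cache

B, K, H, cache_size = 1, 1, 256, 16
cache = prefill(jnp.ones((B, 4, K, H)), jnp.ones((B, 4, K, H)), cache_size)
cache = decode_step(jnp.ones((B, 1, K, H)), jnp.ones((B, 1, K, H)), cache)  # idx: 4 -> 5

Attention is then computed against the full padded cache, with the attention mask zeroing out the not-yet-written slots, which keeps every decode step the same shape for jit.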
capvector-pi05/src/openpi/models/lora.py
ADDED
@@ -0,0 +1,148 @@

import math
import re

import flax.linen as nn
import flax.struct as struct
import jax.numpy as jnp

import openpi.shared.array_typing as at


@struct.dataclass
class LoRAConfig:
    """Configuration for LoRA."""

    # LoRA rank.
    rank: int
    # LoRA scaling factor.
    alpha: float = 1.0
    # Initialization function for LoRA parameters.
    init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
    # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
    rslora: bool = False
    # Axes in the weight to apply LoRA to. Should typically be the last two axes.
    axes: tuple[int, int] = (-2, -1)
    # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
    label: str = "L"

    @property
    def scaling_value(self) -> float:
        return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank


class Einsum(nn.Module):
    """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""

    # Shape of the weight.
    shape: tuple[int, ...]
    # Initialization function for the weight.
    init_fn: nn.initializers.Initializer = nn.initializers.zeros
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w = self.param("w", self.init_fn, self.shape)

        if config := self.lora_config:
            # Setup LoRA parameters.
            shape_a, shape_b = list(self.shape), list(self.shape)
            shape_a[config.axes[1]] = config.rank
            shape_b[config.axes[0]] = config.rank
            self.w_a = self.param("lora_a", config.init_fn, shape_a)
            self.w_b = self.param("lora_b", config.init_fn, shape_b)

    @nn.compact
    def __call__(self, eqn: str, x):
        dtype = x.dtype  # original dtype, could be half-precision
        result = jnp.einsum(eqn, x, self.w.astype(dtype))

        if config := self.lora_config:
            eqn_a, eqn_b = self._make_lora_eqns(eqn)
            lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
            lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
            result = result + lora * config.scaling_value

        return result

    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
        if "L" in eqn:
            raise ValueError(f"L already in eqn: {eqn}")
        if not (m := re.match("(.*),(.*)->(.*)", eqn)):
            raise ValueError(f"Unsupported einsum eqn: {eqn}")
        lhs, rhs, out = m.groups()

        assert self.lora_config is not None
        a_label, b_label = (rhs[x] for x in self.lora_config.axes)
        label = self.lora_config.label

        a_rhs = rhs.replace(b_label, label)
        a_out = out.replace(b_label, label)
        eqn_a = f"{lhs},{a_rhs}->{a_out}"

        b_rhs = rhs.replace(a_label, label)
        eqn_b = f"{a_out},{b_rhs}->{out}"

        return eqn_a, eqn_b


class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        )
        self.w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        )
        self.w_gating_lora = None
        self.w_linear_lora = None
        if self.lora_config:
            # Setup LoRA parameters.
            # TODO: follow up with a simplified init_fn api.
            self.w_gating_lora = (
                self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
                self.param(
                    "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
                ),
            )
            self.w_linear_lora = (
                self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
                self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
            )

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        ff_gate = self._dot(
            x,
            self.w_gating[0],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
        )
        gate_value = nn.gelu(ff_gate)

        ff1 = self._dot(
            x,
            self.w_gating[1],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
        )
        activations = gate_value * ff1

        outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
        assert outputs.dtype == dtype
        return outputs

    def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
        base = jnp.dot(x, w.astype(x.dtype))
        if lora_weights is None:
            return base
        return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
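Note: a small usage sketch of `lora.Einsum`, showing how `_make_lora_eqns` factors one equation into two low-rank einsums; the concrete equation and shapes below are illustrative choices, not the only supported ones:

# For eqn "BTD,NDH->BTNH" with axes=(-2, -1) and label "L":
#   lora_a has shape (N, D, r)  ->  eqn_a: "BTD,NDL->BTNL"
#   lora_b has shape (N, r, H)  ->  eqn_b: "BTNL,NLH->BTNH"
# so the output is  x @ W  +  (alpha / r) * ((x @ A) @ B).
import jax
import jax.numpy as jnp

import openpi.models.lora as lora

module = lora.Einsum(
    shape=(8, 32, 4),  # (N, D, H)
    lora_config=lora.LoRAConfig(rank=2, alpha=16.0),  # scaling_value = alpha / rank = 8.0
)
x = jnp.ones((1, 5, 32))  # (B, T, D)
params = module.init(jax.random.key(0), "BTD,NDH->BTNH", x)
y = module.apply(params, "BTD,NDH->BTNH", x)  # shape (1, 5, 8, 4)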
capvector-pi05/src/openpi/models/lora_test.py
ADDED
@@ -0,0 +1,94 @@

import flax.linen as nn
import jax
import jax.numpy as jnp

import openpi.models.lora as lora


def test_lora_einsum_params_shape():
    shape = (3, 8, 32, 4)  # (3KDH)
    einsum = lora.Einsum(shape)
    lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
    lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))

    key = jax.random.key(0)
    x = jax.random.normal(key, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"

    # Ensure that lora parameters are not initialized when LoRA is not used.
    params = einsum.init(key, eqn, x)
    assert "lora_a" not in params["params"]
    assert "lora_b" not in params["params"]

    # Check that default axes work.
    params_lora0 = lora0.init(key, eqn, x)
    assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
    assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)

    # Check that user provided axes work.
    params_lora1 = lora1.init(key, eqn, x)
    assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
    assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)


def test_lora_einsum_same_output():
    shape = (3, 8, 32, 4)  # (3KDH)
    einsum = lora.Einsum(shape)
    einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))

    key = jax.random.key(0)
    x = jax.random.normal(key, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"

    params = einsum.init(key, eqn, x)
    output = einsum.apply(params, eqn, x)

    params_lora = einsum_lora.init(key, eqn, x)
    output_lora = einsum_lora.apply(params_lora, eqn, x)

    # Results are the same since the LoRA parameters are initialized to zeros.
    assert jnp.allclose(output, output_lora)


def test_lora_ffn_params_shape():
    ffn = lora.FeedForward(features=8, hidden_dim=32)
    ffn_lora = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2),
    )

    key = jax.random.key(0)
    x = jax.random.normal(key, (2, 8))

    params = ffn.init(key, x)
    assert params["params"]["gating_einsum"].shape == (2, 8, 32)
    assert params["params"]["linear"].shape == (32, 8)

    params_lora = ffn_lora.init(key, x)
    assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
    assert params_lora["params"]["linear"].shape == (32, 8)
    assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
    assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
    assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
    assert params_lora["params"]["linear_lora_b"].shape == (2, 8)


def test_lora_ffn_same_output():
    ffn = lora.FeedForward(features=8, hidden_dim=32)
    ffn_lora = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
    )

    key = jax.random.key(0)
    x = jax.random.normal(key, (2, 8))

    params = ffn.init(key, x)
    output = ffn.apply(params, x)

    params_lora = ffn_lora.init(key, x)
    output_lora = ffn_lora.apply(params_lora, x)

    assert jnp.allclose(output, output_lora)
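Note: assuming the package and its JAX/Flax test dependencies are installed, the tests above run directly with pytest, e.g. `pytest capvector-pi05/src/openpi/models/lora_test.py`.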
capvector-pi05/src/openpi/models/model.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from collections.abc import Sequence
|
| 3 |
+
import dataclasses
|
| 4 |
+
import enum
|
| 5 |
+
import logging
|
| 6 |
+
import pathlib
|
| 7 |
+
from typing import Generic, TypeVar
|
| 8 |
+
|
| 9 |
+
import augmax
|
| 10 |
+
from flax import nnx
|
| 11 |
+
from flax import struct
|
| 12 |
+
from flax import traverse_util
|
| 13 |
+
import jax
|
| 14 |
+
import jax.numpy as jnp
|
| 15 |
+
import numpy as np
|
| 16 |
+
import orbax.checkpoint as ocp
|
| 17 |
+
import safetensors
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from openpi.models_pytorch import pi0_pytorch
|
| 21 |
+
from openpi.shared import image_tools
|
| 22 |
+
import openpi.shared.array_typing as at
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("openpi")
|
| 25 |
+
|
| 26 |
+
# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
|
| 27 |
+
ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ModelType(enum.Enum):
|
| 31 |
+
"""Supported model types."""
|
| 32 |
+
|
| 33 |
+
PI0 = "pi0"
|
| 34 |
+
PI0_FAST = "pi0_fast"
|
| 35 |
+
PI05 = "pi05"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# The model always expects these images
|
| 39 |
+
IMAGE_KEYS = (
|
| 40 |
+
"base_0_rgb",
|
| 41 |
+
"left_wrist_0_rgb",
|
| 42 |
+
"right_wrist_0_rgb",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# This may need change if we release a small model.
|
| 47 |
+
IMAGE_RESOLUTION = (224, 224)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Data format
|
| 51 |
+
#
|
| 52 |
+
# Data transforms produce the model input as a nested dictionary which is later converted
|
| 53 |
+
# into `Obesrvation` and `Actions` objects. See below.
|
| 54 |
+
#
|
| 55 |
+
# In the dictory form, this data should look like:
|
| 56 |
+
# {
|
| 57 |
+
# # Observation data.
|
| 58 |
+
# "image": {
|
| 59 |
+
# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
|
| 60 |
+
# ... # Additional camera views
|
| 61 |
+
# },
|
| 62 |
+
# "image_mask": {
|
| 63 |
+
# "base_0_rgb": bool[*b], # True if image is valid
|
| 64 |
+
# ... # Masks for additional views
|
| 65 |
+
# },
|
| 66 |
+
# "state": float32[*b, s], # Low-dimensional robot state
|
| 67 |
+
# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
|
| 68 |
+
# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
|
| 69 |
+
# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
|
| 70 |
+
# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
|
| 71 |
+
#
|
| 72 |
+
# # Actions data.
|
| 73 |
+
# "actions": float32[*b ah ad]
|
| 74 |
+
# }
|
| 75 |
+
# where:
|
| 76 |
+
# *b = batch dimensions
|
| 77 |
+
# h,w = image height/width
|
| 78 |
+
# s = state dimension
|
| 79 |
+
# l = sequence length
|
| 80 |
+
#
|
| 81 |
+
@at.typecheck
|
| 82 |
+
@struct.dataclass
|
| 83 |
+
class Observation(Generic[ArrayT]):
|
| 84 |
+
"""Holds observations, i.e., inputs to the model.
|
| 85 |
+
|
| 86 |
+
See `Observation.from_dict` to see the expected dictionary form. This is the format
|
| 87 |
+
that should be produced by the data transforms.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
# Images, in [-1, 1] float32.
|
| 91 |
+
images: dict[str, at.Float[ArrayT, "*b h w c"]]
|
| 92 |
+
# the padding area for non-rectangular input images is False
|
| 93 |
+
image_padding_mask: dict[str, at.Bool[ArrayT, "*b w c"]]
|
| 94 |
+
# Image masks, with same keys as images.
|
| 95 |
+
image_masks: dict[str, at.Bool[ArrayT, "*b"]]
|
| 96 |
+
# Low-dimensional robot state.
|
| 97 |
+
state: at.Float[ArrayT, "*b s"]
|
| 98 |
+
|
| 99 |
+
# Tokenized prompt.
|
| 100 |
+
tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
|
| 101 |
+
# Tokenized prompt mask.
|
| 102 |
+
tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
|
| 103 |
+
|
| 104 |
+
# pi0-fast model specific fields.
|
| 105 |
+
|
| 106 |
+
# Token auto-regressive mask (for FAST autoregressive model).
|
| 107 |
+
token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
|
| 108 |
+
# Token loss mask (for FAST autoregressive model).
|
| 109 |
+
token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
|
| 113 |
+
"""This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
|
| 114 |
+
# Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
|
| 115 |
+
if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
|
| 116 |
+
raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
|
| 117 |
+
# If images are uint8, convert them to [-1, 1] float32.
|
| 118 |
+
for key in data["image"]:
|
| 119 |
+
if data["image"][key].dtype == np.uint8:
|
| 120 |
+
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
|
| 121 |
+
elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
|
| 122 |
+
data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
|
| 123 |
+
return cls(
|
| 124 |
+
images=data["image"],
|
| 125 |
+
image_padding_mask=data.get("image_padding_mask", {}),
|
| 126 |
+
image_masks=data["image_mask"],
|
| 127 |
+
state=data["state"],
|
| 128 |
+
tokenized_prompt=data.get("tokenized_prompt"),
|
| 129 |
+
tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
|
| 130 |
+
token_ar_mask=data.get("token_ar_mask"),
|
| 131 |
+
token_loss_mask=data.get("token_loss_mask"),
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def to_dict(self) -> at.PyTree[ArrayT]:
|
| 135 |
+
"""Convert the Observation to a nested dict."""
|
| 136 |
+
result = dataclasses.asdict(self)
|
| 137 |
+
result["image"] = result.pop("images")
|
| 138 |
+
result["image_mask"] = result.pop("image_masks")
|
| 139 |
+
return result
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Defines the format of the actions. This field is included as "actions" inside the dictionary
|
| 143 |
+
# produced by the data transforms.
|
| 144 |
+
Actions = at.Float[ArrayT, "*b ah ad"]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def preprocess_observation(
|
| 148 |
+
rng: at.KeyArrayLike | None,
|
| 149 |
+
observation: Observation,
|
| 150 |
+
*,
|
| 151 |
+
train: bool = False,
|
| 152 |
+
image_keys: Sequence[str] = IMAGE_KEYS,
|
| 153 |
+
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
| 154 |
+
) -> Observation:
|
| 155 |
+
"""Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
|
| 156 |
+
filling in a default image mask (if necessary).
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
if not set(image_keys).issubset(observation.images):
|
| 160 |
+
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
| 161 |
+
|
| 162 |
+
batch_shape = observation.state.shape[:-1]
|
| 163 |
+
|
| 164 |
+
out_images = {}
|
| 165 |
+
for key in image_keys:
|
| 166 |
+
image = observation.images[key]
|
| 167 |
+
if image.shape[1:3] != image_resolution:
|
| 168 |
+
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
| 169 |
+
image = image_tools.resize_with_pad(image, *image_resolution)
|
| 170 |
+
|
| 171 |
+
if train:
|
| 172 |
+
# Convert from [-1, 1] to [0, 1] for augmax.
|
| 173 |
+
image = image / 2.0 + 0.5
|
| 174 |
+
|
| 175 |
+
transforms = []
|
| 176 |
+
if "wrist" not in key:
|
| 177 |
+
height, width = image.shape[1:3]
|
| 178 |
+
transforms += [
|
| 179 |
+
augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
|
| 180 |
+
augmax.Resize(width, height),
|
| 181 |
+
augmax.Rotate((-5, 5)),
|
| 182 |
+
]
|
| 183 |
+
transforms += [
|
| 184 |
+
augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
|
| 185 |
+
]
|
| 186 |
+
sub_rngs = jax.random.split(rng, image.shape[0])
|
| 187 |
+
image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
|
| 188 |
+
|
| 189 |
+
# Back to [-1, 1].
|
| 190 |
+
image = image * 2.0 - 1.0
|
| 191 |
+
|
| 192 |
+
out_images[key] = image
|
| 193 |
+
|
| 194 |
+
# obtain mask
|
| 195 |
+
out_masks = {}
|
| 196 |
+
for key in out_images:
|
| 197 |
+
if key not in observation.image_masks:
|
| 198 |
+
# do not mask by default
|
| 199 |
+
out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
|
| 200 |
+
else:
|
| 201 |
+
out_masks[key] = jnp.asarray(observation.image_masks[key])
|
| 202 |
+
|
| 203 |
+
return Observation(
|
| 204 |
+
images=out_images,
|
| 205 |
+
image_masks=out_masks,
|
| 206 |
+
state=observation.state,
|
| 207 |
+
tokenized_prompt=observation.tokenized_prompt,
|
| 208 |
+
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
| 209 |
+
token_ar_mask=observation.token_ar_mask,
|
| 210 |
+
token_loss_mask=observation.token_loss_mask,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@dataclasses.dataclass(frozen=True)
|
| 215 |
+
class BaseModelConfig(abc.ABC):
|
| 216 |
+
"""Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
|
| 217 |
+
method to create the corresponding model.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
# Action space dimension.
|
| 221 |
+
action_dim: int
|
| 222 |
+
# Action sequence length.
|
| 223 |
+
action_horizon: int
|
| 224 |
+
# Tokenized prompt maximum length.
|
| 225 |
+
max_token_len: int
|
| 226 |
+
|
| 227 |
+
@property
|
| 228 |
+
@abc.abstractmethod
|
| 229 |
+
def model_type(self) -> ModelType:
|
| 230 |
+
"""The model type."""
|
| 231 |
+
|
| 232 |
+
@abc.abstractmethod
|
| 233 |
+
def create(self, rng: at.KeyArrayLike) -> "BaseModel":
|
| 234 |
+
"""Create a new model, initializing parameters."""
|
| 235 |
+
|
| 236 |
+
def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
|
| 237 |
+
"""Create a model with the given parameters."""
|
| 238 |
+
model = nnx.eval_shape(self.create, jax.random.key(0))
|
| 239 |
+
graphdef, state = nnx.split(model)
|
| 240 |
+
if remove_extra_params:
|
| 241 |
+
params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
|
| 242 |
+
at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
|
| 243 |
+
state.replace_by_pure_dict(params)
|
| 244 |
+
return nnx.merge(graphdef, state)
|
| 245 |
+
|
    def load_pytorch(self, train_config, weight_path: str):
        logger.info(f"train_config: {train_config}")
        model = pi0_pytorch.PI0Pytorch(config=train_config.model)
        safetensors.torch.load_model(model, weight_path)
        return model

    @abc.abstractmethod
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
        """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""

    def fake_obs(self, batch_size: int = 1) -> Observation:
        observation_spec, _ = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)

    def fake_act(self, batch_size: int = 1) -> Actions:
        _, action_spec = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)


@dataclasses.dataclass
class BaseModel(nnx.Module, abc.ABC):
    """Base class for all model implementations. Specific models should inherit from this class. They should call
    super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
    """

    action_dim: int
    action_horizon: int
    max_token_len: int

    @abc.abstractmethod
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: Observation,
        actions: Actions,
        *,
        train: bool = False,
    ) -> at.Float[at.Array, "*b ah"]: ...

    @abc.abstractmethod
    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...


def restore_params(
    params_path: pathlib.Path | str,
    *,
    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Restores unstructured params PyTree from a checkpoint.

    This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
    well as pre-trained checkpoints released for openpi.

    Args:
        params_path: The local path (or gs:// path) to the checkpoint directory.
        restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
        dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
        sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.

    Returns:
        The restored params.
    """
    params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path

    if restore_type is jax.Array and sharding is None:
        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    with ocp.PyTreeCheckpointer() as ckptr:
        metadata = ckptr.metadata(params_path)
        item = {"params": metadata["params"]}

        params = ckptr.restore(
            params_path,
            ocp.args.PyTreeRestore(
                item=item,
                restore_args=jax.tree.map(
                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
                ),
            ),
        )["params"]

    # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
    # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
    flat_params = traverse_util.flatten_dict(params)
    if all(kp[-1] == "value" for kp in flat_params):
        flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
    return traverse_util.unflatten_dict(flat_params)
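
Taken together, `restore_params` and `BaseModelConfig.load` form the standard loading path. A minimal sketch (hedged: the checkpoint directory is illustrative, and `Pi0Config` is defined in pi0_config.py below):

import numpy as np

from openpi.models import model as _model
from openpi.models import pi0_config

config = pi0_config.Pi0Config()
# restore_type=np.ndarray keeps the restored weights as host numpy arrays
# instead of sharded jax.Arrays (see the docstring above).
params = _model.restore_params("./checkpoints/pi0_base/params", restore_type=np.ndarray)
model = config.load(params)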
capvector-pi05/src/openpi/models/model_test.py ADDED
@@ -0,0 +1,94 @@
from flax import nnx
import jax
import pytest

from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.models import pi0_fast
from openpi.shared import download
from openpi.shared import nnx_utils


def test_pi0_model():
    key = jax.random.key(0)
    config = pi0_config.Pi0Config()
    model = config.create(key)

    batch_size = 2
    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)


def test_pi0_lora_model():
    key = jax.random.key(0)
    config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    model = config.create(key)

    batch_size = 2
    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)


def test_pi0_fast_model():
    key = jax.random.key(0)
    config = pi0_fast.Pi0FASTConfig()
    model = config.create(key)

    batch_size = 2
    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
    assert loss.shape == (batch_size,)

    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
    assert actions.shape == (batch_size, 256)


def test_pi0_fast_lora_model():
    key = jax.random.key(0)
    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
    model = config.create(key)

    batch_size = 2
    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
    assert loss.shape == (batch_size,)

    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
    assert actions.shape == (batch_size, 256)

    lora_filter = nnx_utils.PathRegex(".*lora.*")
    model_state = nnx.state(model)

    lora_state_elems = list(model_state.filter(lora_filter))
    assert len(lora_state_elems) > 0


@pytest.mark.manual
def test_model_restore():
    key = jax.random.key(0)
    config = pi0_config.Pi0Config()

    batch_size = 2
    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)

    model = config.load(
        _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
    )

    loss = model.compute_loss(key, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    actions = model.sample_actions(key, obs, num_steps=10)
    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
capvector-pi05/src/openpi/models/pi0.py ADDED
@@ -0,0 +1,279 @@
import logging

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
from openpi.models import pi0_config
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at

logger = logging.getLogger("openpi")


def make_attn_mask(input_mask, mask_ar):
    """Adapted from big_vision.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      input_mask: bool[B, N] true if it's part of the input, false if padding.
      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
        it and false where it shares the same attention mask as the previous token.
    """
    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    cumsum = jnp.cumsum(mask_ar, axis=1)
    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(attn_mask, valid_mask)

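# Worked example (illustrative, not part of the original file): with
# mask_ar = [[0, 0, 1, 1]] and all four tokens valid, cumsum = [0, 0, 1, 2] and
# token i may attend to token j iff cumsum[j] <= cumsum[i]. So
#     make_attn_mask(jnp.ones((1, 4), bool), jnp.array([[0, 0, 1, 1]]))
# lets tokens 0-1 attend to each other bidirectionally, token 2 attend to
# tokens 0-2, and token 3 attend to everything: prefix-lm with a causal tail.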

@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
    period = min_period * (max_period / min_period) ** fraction
    sinusoid_input = jnp.einsum(
        "i,j->ij",
        pos,
        1.0 / period * 2 * jnp.pi,
        precision=jax.lax.Precision.HIGHEST,
    )
    return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)

|
| 66 |
+
class Pi0(_model.BaseModel):
|
| 67 |
+
def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
|
| 68 |
+
super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
|
| 69 |
+
self.pi05 = config.pi05
|
| 70 |
+
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
| 71 |
+
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
| 72 |
+
# TODO: rewrite gemma in NNX. For now, use bridge.
|
| 73 |
+
llm = nnx_bridge.ToNNX(
|
| 74 |
+
_gemma.Module(
|
| 75 |
+
configs=[paligemma_config, action_expert_config],
|
| 76 |
+
embed_dtype=config.dtype,
|
| 77 |
+
adarms=config.pi05,
|
| 78 |
+
)
|
| 79 |
+
)
|
| 80 |
+
llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
|
| 81 |
+
img = nnx_bridge.ToNNX(
|
| 82 |
+
_siglip.Module(
|
| 83 |
+
num_classes=paligemma_config.width,
|
| 84 |
+
variant="So400m/14",
|
| 85 |
+
pool_type="none",
|
| 86 |
+
scan=True,
|
| 87 |
+
dtype_mm=config.dtype,
|
| 88 |
+
)
|
| 89 |
+
)
|
| 90 |
+
img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
|
| 91 |
+
self.PaliGemma = nnx.Dict(llm=llm, img=img)
|
| 92 |
+
self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
|
| 93 |
+
if config.pi05:
|
| 94 |
+
self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
|
| 95 |
+
self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
|
| 96 |
+
else:
|
| 97 |
+
self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
|
| 98 |
+
self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
|
| 99 |
+
self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
|
| 100 |
+
self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
|
| 101 |
+
|
| 102 |
+
# This attribute gets automatically set by model.train() and model.eval().
|
| 103 |
+
self.deterministic = True
|
| 104 |
+
|
| 105 |
+
@at.typecheck
|
| 106 |
+
def embed_prefix(
|
| 107 |
+
self, obs: _model.Observation
|
| 108 |
+
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
|
| 109 |
+
input_mask = []
|
| 110 |
+
ar_mask = []
|
| 111 |
+
tokens = []
|
| 112 |
+
# embed images
|
| 113 |
+
for name in obs.images:
|
| 114 |
+
image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
|
| 115 |
+
|
| 116 |
+
tokens.append(image_tokens)
|
| 117 |
+
input_mask.append(
|
| 118 |
+
einops.repeat(
|
| 119 |
+
obs.image_masks[name],
|
| 120 |
+
"b -> b s",
|
| 121 |
+
s=image_tokens.shape[1],
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
# image tokens attend to each other
|
| 125 |
+
ar_mask += [False] * image_tokens.shape[1]
|
| 126 |
+
|
| 127 |
+
# add language (aka tokenized inputs)
|
| 128 |
+
if obs.tokenized_prompt is not None:
|
| 129 |
+
tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
|
| 130 |
+
tokens.append(tokenized_inputs)
|
| 131 |
+
input_mask.append(obs.tokenized_prompt_mask)
|
| 132 |
+
# full attention between image and language inputs
|
| 133 |
+
ar_mask += [False] * tokenized_inputs.shape[1]
|
| 134 |
+
tokens = jnp.concatenate(tokens, axis=1)
|
| 135 |
+
input_mask = jnp.concatenate(input_mask, axis=1)
|
| 136 |
+
ar_mask = jnp.array(ar_mask)
|
| 137 |
+
return tokens, input_mask, ar_mask
|
| 138 |
+
|
| 139 |
+
@at.typecheck
|
| 140 |
+
def embed_suffix(
|
| 141 |
+
self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
|
| 142 |
+
) -> tuple[
|
| 143 |
+
at.Float[at.Array, "b s emb"],
|
| 144 |
+
at.Bool[at.Array, "b s"],
|
| 145 |
+
at.Bool[at.Array, " s"],
|
| 146 |
+
at.Float[at.Array, "b emb"] | None,
|
| 147 |
+
]:
|
| 148 |
+
input_mask = []
|
| 149 |
+
ar_mask = []
|
| 150 |
+
tokens = []
|
| 151 |
+
if not self.pi05:
|
| 152 |
+
# add a single state token
|
| 153 |
+
state_token = self.state_proj(obs.state)[:, None, :]
|
| 154 |
+
tokens.append(state_token)
|
| 155 |
+
input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
|
| 156 |
+
# image/language inputs do not attend to state or actions
|
| 157 |
+
ar_mask += [True]
|
| 158 |
+
|
| 159 |
+
action_tokens = self.action_in_proj(noisy_actions)
|
| 160 |
+
# embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
| 161 |
+
time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
|
| 162 |
+
if self.pi05:
|
| 163 |
+
# time MLP (for adaRMS)
|
| 164 |
+
time_emb = self.time_mlp_in(time_emb)
|
| 165 |
+
time_emb = nnx.swish(time_emb)
|
| 166 |
+
time_emb = self.time_mlp_out(time_emb)
|
| 167 |
+
time_emb = nnx.swish(time_emb)
|
| 168 |
+
action_expert_tokens = action_tokens
|
| 169 |
+
adarms_cond = time_emb
|
| 170 |
+
else:
|
| 171 |
+
# mix timestep + action information using an MLP (no adaRMS)
|
| 172 |
+
time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
|
| 173 |
+
action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
|
| 174 |
+
action_time_tokens = self.action_time_mlp_in(action_time_tokens)
|
| 175 |
+
action_time_tokens = nnx.swish(action_time_tokens)
|
| 176 |
+
action_time_tokens = self.action_time_mlp_out(action_time_tokens)
|
| 177 |
+
action_expert_tokens = action_time_tokens
|
| 178 |
+
adarms_cond = None
|
| 179 |
+
tokens.append(action_expert_tokens)
|
| 180 |
+
input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
|
| 181 |
+
# image/language/state inputs do not attend to action tokens
|
| 182 |
+
ar_mask += [True] + ([False] * (self.action_horizon - 1))
|
| 183 |
+
tokens = jnp.concatenate(tokens, axis=1)
|
| 184 |
+
input_mask = jnp.concatenate(input_mask, axis=1)
|
| 185 |
+
ar_mask = jnp.array(ar_mask)
|
| 186 |
+
return tokens, input_mask, ar_mask, adarms_cond
|
| 187 |
+
|
| 188 |
+
@override
|
| 189 |
+
def compute_loss(
|
| 190 |
+
self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
|
| 191 |
+
) -> at.Float[at.Array, "*b ah"]:
|
| 192 |
+
preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
|
| 193 |
+
observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
|
| 194 |
+
|
| 195 |
+
batch_shape = actions.shape[:-2]
|
| 196 |
+
noise = jax.random.normal(noise_rng, actions.shape)
|
| 197 |
+
time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
|
| 198 |
+
time_expanded = time[..., None, None]
|
| 199 |
+
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
| 200 |
+
u_t = noise - actions
|
| 201 |
+
|
| 202 |
+
# one big forward pass of prefix + suffix at once
|
| 203 |
+
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
|
| 204 |
+
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
|
| 205 |
+
input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
|
| 206 |
+
ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
|
| 207 |
+
attn_mask = make_attn_mask(input_mask, ar_mask)
|
| 208 |
+
positions = jnp.cumsum(input_mask, axis=1) - 1
|
| 209 |
+
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
|
| 210 |
+
[prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
|
| 211 |
+
)
|
| 212 |
+
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
|
| 213 |
+
|
| 214 |
+
return jnp.mean(jnp.square(v_t - u_t), axis=-1)
|
| 215 |
+
|
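    # Flow-matching recap (illustrative comment, not in the original file): with
    # t ~ Beta(1.5, 1) rescaled into (0, 1] and noise eps ~ N(0, I), the code above
    # forms x_t = t * eps + (1 - t) * a and the target velocity u_t = eps - a, then
    # regresses the predicted velocity v_t onto u_t. Here t = 1 is pure noise and
    # t = 0 is data, matching `sample_actions` below, which integrates from t = 1
    # down to t = 0 with dt = -1 / num_steps.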
    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
    ) -> _model.Actions:
        observation = _model.preprocess_observation(None, observation, train=False)
        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        if noise is None:
            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

        # first fill KV cache with a forward pass of the prefix
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        positions = jnp.cumsum(prefix_mask, axis=1) - 1
        _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

        def step(carry):
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
            # other
            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
            # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
            # prefix tokens
            prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
            # `full_attn_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
            full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
            assert full_attn_mask.shape == (
                batch_size,
                suffix_tokens.shape[1],
                prefix_tokens.shape[1] + suffix_tokens.shape[1],
            )
            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
                [None, suffix_tokens],
                mask=full_attn_mask,
                positions=positions,
                kv_cache=kv_cache,
                adarms_cond=[None, adarms_cond],
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

            return x_t + dt * v_t, time + dt

        def cond(carry):
            x_t, time = carry
            # robust to floating-point error
            return time >= -dt / 2

        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
        return x_0
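
A minimal usage sketch for the sampler above (hedged: random weights via `create`; passing a fixed `noise` makes sampling deterministic for a given observation):

import jax
import jax.numpy as jnp

from openpi.models import pi0_config

config = pi0_config.Pi0Config()
model = config.create(jax.random.key(0))  # or config.load(...) for real weights
obs = config.fake_obs(batch_size=1)
noise = jnp.zeros((1, config.action_horizon, config.action_dim))
actions = model.sample_actions(jax.random.key(0), obs, num_steps=10, noise=noise)
assert actions.shape == (1, config.action_horizon, config.action_dim)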
capvector-pi05/src/openpi/models/pi0_config.py ADDED
@@ -0,0 +1,108 @@
import dataclasses
from typing import TYPE_CHECKING

import flax.nnx as nnx
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
import openpi.models.gemma as _gemma
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils

if TYPE_CHECKING:
    from openpi.models.pi0 import Pi0


@dataclasses.dataclass(frozen=True)
class Pi0Config(_model.BaseModelConfig):
    dtype: str = "bfloat16"
    paligemma_variant: _gemma.Variant = "gemma_2b"
    action_expert_variant: _gemma.Variant = "gemma_300m"

    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    max_token_len: int = None  # type: ignore
    # Pi05 has two differences from Pi0:
    # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
    # - the action expert uses adaRMSNorm to inject the flow matching timestep
    pi05: bool = False
    # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
    discrete_state_input: bool = None  # type: ignore

    def __post_init__(self):
        if self.max_token_len is None:
            object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
        if self.discrete_state_input is None:
            object.__setattr__(self, "discrete_state_input", self.pi05)

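    # Resolution of the defaults above (illustrative): Pi0Config() yields
    # max_token_len=48 and discrete_state_input=False, while Pi0Config(pi05=True)
    # yields max_token_len=200 and discrete_state_input=True. Explicit values win,
    # e.g. Pi0Config(pi05=True, max_token_len=250).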
    @property
    @override
    def model_type(self) -> _model.ModelType:
        if self.pi05:
            return _model.ModelType.PI05
        return _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0":
        from openpi.models.pi0 import Pi0

        return Pi0(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "left_wrist_0_rgb": image_spec,
                    "right_wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "left_wrist_0_rgb": image_mask_spec,
                    "right_wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config."""
        filters = []
        has_lora = False
        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        if "lora" in self.paligemma_variant:
            filters.append(
                gemma_params_filter,
            )
            if "lora" not in self.action_expert_variant:
                # If only freezing gemma params, exclude action expert params.
                filters.append(
                    nnx.Not(action_expert_params_filter),
                )
            has_lora = True
        elif "lora" in self.action_expert_variant:
            filters.append(
                action_expert_params_filter,
            )
            has_lora = True

        if has_lora:
            # If any lora is used, exclude all lora params.
            filters.append(
                nnx.Not(nnx_utils.PathRegex(".*lora.*")),
            )
        if not filters:
            return nnx.Nothing
        return nnx.All(*filters)
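
A minimal sketch of using the freeze filter for LoRA fine-tuning (hedged: relies on flax.nnx's standard filter API; the variant string matches the test file above):

import jax
from flax import nnx

from openpi.models import pi0_config

config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
model = config.create(jax.random.key(0))
freeze_filter = config.get_freeze_filter()
frozen_state = nnx.state(model, freeze_filter)              # kept fixed
trainable_state = nnx.state(model, nnx.Not(freeze_filter))  # LoRA params etc.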
capvector-pi05/src/openpi/models/pi0_fast.py ADDED
@@ -0,0 +1,313 @@
import dataclasses
import logging
from typing import Any

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
import openpi.models.gemma_fast as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils

logger = logging.getLogger("openpi")

PALIGEMMA_EOS_TOKEN = 1


def make_attn_mask(input_mask, mask_ar):
    """Adapted from big_vision.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      input_mask: bool[B, N] true if it's part of the input, false if padding.
      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
        it and false where it shares the same attention mask as the previous token.
    """
    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    cumsum = jnp.cumsum(mask_ar, axis=1)
    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(attn_mask, valid_mask)


@jax.vmap
def left_to_right_align(x, input_mask, attn_mask):
    """Converts input from left-aligned to right-aligned."""
    # Due to vmap, this is operating on a single example (not batch level).
    assert x.ndim == 2
    assert input_mask.ndim == 1
    assert attn_mask.ndim == 2
    assert x.shape[0] == input_mask.shape[0]
    assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
    seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
    x = jnp.roll(x, -seqlen, axis=0)
    input_mask = jnp.roll(input_mask, -seqlen, axis=0)
    attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))
    return x, input_mask, attn_mask

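# Worked example (illustrative): for one sequence of length 4 with
# input_mask = [1, 1, 0, 0], seqlen = 2, so rolling by -2 moves the two valid
# tokens into the last two slots: x becomes [x2, x3, x0, x1] with mask
# [0, 0, 1, 1]. Right alignment makes the prefill end exactly where
# autoregressive decoding starts, simplifying the KV-cache bookkeeping below.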

def put_along_last_axis(arr, indices, values):
    """Like np.put_along_axis(..., axis=-1), since jax is missing it."""
    assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)
    onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
    put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot)
    put_values = jnp.einsum("...i,...in->...n", values, onehot)
    return jnp.where(put_mask, put_values, arr)

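# Worked example (illustrative):
#     put_along_last_axis(jnp.zeros((1, 4)), jnp.array([[2]]), jnp.array([[7]]))
# returns [[0, 0, 7, 0]]: the index is one-hot encoded and the value scattered
# in, leaving the rest of `arr` untouched. The decoding loop below uses this to
# write each sampled token into `output_tokens` at position `step`.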

@dataclasses.dataclass(frozen=True)
class Pi0FASTConfig(_model.BaseModelConfig):
    dtype: str = "bfloat16"
    paligemma_variant: _gemma.Variant = "gemma_2b"

    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 32
    max_token_len: int = 250

    # Tokenizer for the fast model.
    fast_model_tokenizer: Any | None = None
    # Keyword arguments for the fast model tokenizer.
    fast_model_tokenizer_kwargs: dict[str, Any] | None = None

    @property
    @override
    def model_type(self) -> _model.ModelType:
        return _model.ModelType.PI0_FAST

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
        return Pi0FAST(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "base_1_rgb": image_spec,
                    "wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "base_1_rgb": image_mask_spec,
                    "wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
                token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config."""
        if "lora" in self.paligemma_variant:
            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        return nnx.Nothing


class Pi0FAST(_model.BaseModel):
    def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                **paligemma_config,
                embed_dtype=config.dtype,
                cache_dtype=config.dtype,
            )
        )
        llm.lazy_init(rngs=rngs, method="init")
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)

    @at.typecheck
    def embed_inputs(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
        input_mask = []
        ar_mask = []
        token_embeddings = []
        # embed images
        for name in obs.images:
            image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)

            token_embeddings.append(image_token_embeddings)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_token_embeddings.shape[1],
                )
            )
            # image tokens attend to each other --> AR mask = 0
            ar_mask.append(0 * input_mask[-1])

        # add tokenized inputs
        assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
        assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
        assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
        tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)
        token_embeddings.append(tokenized_inputs_embeddings)
        input_mask.append(obs.tokenized_prompt_mask)
        ar_mask.append(obs.token_ar_mask)

        # return embeddings, input mask, and ar mask
        return (
            jnp.concatenate(token_embeddings, axis=1),
            jnp.concatenate(input_mask, axis=1),
            jnp.concatenate(ar_mask, axis=1),
        )

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, "*b ah"]:
        observation = _model.preprocess_observation(
            rng, observation, train=train, image_keys=list(observation.images.keys())
        )

        # Compute inputs: one big forward pass of prefix + suffix at once
        input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
        attn_mask = make_attn_mask(input_mask, ar_mask)

        # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
        targets = jax.nn.one_hot(
            observation.tokenized_prompt[:, 1:],
            self.PaliGemma.llm.module.vocab_size,
        )

        # Each input predicts *next* token, so we don't input the last token.
        pre_logits, _, _ = self.PaliGemma.llm(
            embedded_prefix=input_token_embeddings[:, :-1],
            mask=attn_mask[:, :-1, :-1],
            return_prelogits=True,
        )

        # Only decode logits for the target tokens to save memory
        # (decoding matmul is large because it is a seq_len x vocab_size dense layer).
        logits, _ = self.PaliGemma.llm(
            pre_logits=pre_logits[:, -targets.shape[1] :],
        )
        logp = jax.nn.log_softmax(logits, axis=-1)

        # Compute CE loss on token targets
        assert observation.token_loss_mask is not None, "Token loss mask is required"
        loss_mask = observation.token_loss_mask[:, 1:]
        token_pplx = jnp.sum(targets * logp, axis=-1)
        return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        max_decoding_steps: int | at.Int[at.Array, ""] = 256,
        temperature: float = 0.0,
    ) -> _model.Actions:
        # TODO: this is a hack to get the image keys.
        observation = _model.preprocess_observation(
            None, observation, train=False, image_keys=list(observation.images.keys())
        )

        # embed inputs
        prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)

        # left to right align all input token sequences
        prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
            prefix_token_embeddings, prefix_mask, prefix_attn_mask
        )
        prefill_size = prefix_token_embeddings.shape[1]
        prefill_len = jnp.sum(prefix_mask, axis=-1)
        prefix_start = prefill_size - prefill_len

        # first fill KV cache with a forward pass of the prefix
        # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
        prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
        prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
        prefix_logits, kv_cache, _ = self.PaliGemma.llm(
            embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
        )

        # prepare decoding -- final logit decodes the first token
        last_logit = prefix_logits[:, -1:]
        output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))

        def step(carry):
            rng, last_logit, output_tokens, cache, _, step = carry

            # Sample token from last logit
            # Split RNG for this step
            rng, rng_step = jax.random.split(rng)
            token = jax.lax.cond(
                temperature > 0.0,
                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
                lambda _: jnp.argmax(last_logit, axis=-1),
                operand=None,
            )
            output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)

            # Check for early stopping --> stop if all batch elements have EOS token
            has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
            all_eos = jnp.all(has_eos)

            # Decode one step
            token_embedding = self.PaliGemma.llm(token, embed_only=True)
            positions = prefill_len[:, None] + step + 1
            mask = jnp.logical_and(
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
                < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
            )
            last_logit, kv_cache, _ = self.PaliGemma.llm(
                embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
            )

            return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1

        def cond(carry):
            _, _, _, _, all_eos, step = carry
            return (~all_eos) & (step < max_decoding_steps)

        # Use lax.while_loop so we can jit the full decoding loop.
        _, _, output_tokens, _, _, _ = jax.lax.while_loop(
            cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
        )
        return output_tokens
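
Note that unlike Pi0's sampler, `sample_actions` here returns decoded token IDs of shape (batch, max_decoding_steps), as the tests above assert; turning those tokens back into a continuous action chunk is left to the FAST tokenizer configured via `fast_model_tokenizer` (hedged: that detokenization step lives outside this file). A minimal invocation sketch:

import jax

from openpi.models import pi0_fast

config = pi0_fast.Pi0FASTConfig()
model = config.create(jax.random.key(0))
obs = config.fake_obs(batch_size=1)
tokens = model.sample_actions(jax.random.key(0), obs, max_decoding_steps=256, temperature=0.0)
# tokens: (1, 256) of token IDs; slots after the first EOS (id 1) stay zero.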