haofuly committed
Commit 5e4171f · verified · 1 Parent(s): 45ac12e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes; see the raw diff for the rest.

Files changed (50)
  1. capvector-pi05/examples/libero/Dockerfile +59 -0
  2. capvector-pi05/examples/libero/README.md +71 -0
  3. capvector-pi05/examples/libero/main.py +219 -0
  4. capvector-pi05/examples/libero/requirements.in +11 -0
  5. capvector-pi05/examples/libero/requirements.txt +136 -0
  6. capvector-pi05/examples/simple_client/Dockerfile +32 -0
  7. capvector-pi05/examples/simple_client/README.md +30 -0
  8. capvector-pi05/examples/simple_client/compose.yml +42 -0
  9. capvector-pi05/examples/simple_client/main.py +187 -0
  10. capvector-pi05/examples/simple_client/requirements.in +5 -0
  11. capvector-pi05/examples/simple_client/requirements.txt +30 -0
  12. capvector-pi05/examples/ur5/README.md +142 -0
  13. capvector-pi05/packages/openpi-client/pyproject.toml +23 -0
  14. capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py +1 -0
  15. capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py +50 -0
  16. capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py +12 -0
  17. capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py +78 -0
  18. capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py +37 -0
  19. capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py +57 -0
  20. capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py +45 -0
  21. capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py +17 -0
  22. capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py +18 -0
  23. capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py +32 -0
  24. capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py +92 -0
  25. capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py +20 -0
  26. capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py +55 -0
  27. capvector-pi05/scripts/__init__.py +0 -0
  28. capvector-pi05/scripts/compute_norm_stats.py +117 -0
  29. capvector-pi05/scripts/docker/compose.yml +29 -0
  30. capvector-pi05/scripts/docker/install_docker_ubuntu22.sh +37 -0
  31. capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh +17 -0
  32. capvector-pi05/scripts/docker/serve_policy.Dockerfile +38 -0
  33. capvector-pi05/scripts/serve_policy.py +122 -0
  34. capvector-pi05/scripts/train.py +280 -0
  35. capvector-pi05/scripts/train_align_pytorch.py +658 -0
  36. capvector-pi05/scripts/train_pytorch.py +632 -0
  37. capvector-pi05/scripts/train_regular_loss_pytorch.py +754 -0
  38. capvector-pi05/scripts/train_test.py +30 -0
  39. capvector-pi05/src/openpi/__init__.py +0 -0
  40. capvector-pi05/src/openpi/conftest.py +17 -0
  41. capvector-pi05/src/openpi/models/__init__.py +0 -0
  42. capvector-pi05/src/openpi/models/gemma.py +459 -0
  43. capvector-pi05/src/openpi/models/gemma_fast.py +437 -0
  44. capvector-pi05/src/openpi/models/lora.py +148 -0
  45. capvector-pi05/src/openpi/models/lora_test.py +94 -0
  46. capvector-pi05/src/openpi/models/model.py +335 -0
  47. capvector-pi05/src/openpi/models/model_test.py +94 -0
  48. capvector-pi05/src/openpi/models/pi0.py +279 -0
  49. capvector-pi05/src/openpi/models/pi0_config.py +108 -0
  50. capvector-pi05/src/openpi/models/pi0_fast.py +313 -0
capvector-pi05/examples/libero/Dockerfile ADDED
@@ -0,0 +1,59 @@
+ # Dockerfile for the LIBERO benchmark.
+
+ # Build the container:
+ # docker build . -t libero -f examples/libero/Dockerfile
+
+ # Run the container:
+ # docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
+
+ FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+ COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+ RUN apt-get update && \
+     apt-get install -y \
+     make \
+     g++ \
+     clang \
+     libosmesa6-dev \
+     libgl1-mesa-glx \
+     libglew-dev \
+     libglfw3-dev \
+     libgles2-mesa-dev \
+     libglib2.0-0 \
+     libsm6 \
+     libxrender1 \
+     libxext6
+
+ WORKDIR /app
+
+ # Copy from the cache instead of linking since it's a mounted volume
+ ENV UV_LINK_MODE=copy
+
+ # Write the virtual environment outside of the project directory so it doesn't
+ # leak out of the container when we mount the application code.
+ ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+ # Copy the requirements files so we can install dependencies.
+ # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+ # This strategy is best for development-style usage.
+ COPY ./examples/libero/requirements.txt /tmp/requirements.txt
+ COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+ # Install python dependencies.
+ RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
+ RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+ ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
+
+ # Create a default config file to avoid an input prompt from LIBERO's init script.
+ # https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
+ ENV LIBERO_CONFIG_PATH=/tmp/libero
+ RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
+ benchmark_root: /app/third_party/libero/libero/libero
+ bddl_files: /app/third_party/libero/libero/libero/bddl_files
+ init_states: /app/third_party/libero/libero/libero/init_files
+ datasets: /app/third_party/libero/libero/datasets
+ assets: /app/third_party/libero/libero/libero/assets
+ EOF
+
+ CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
capvector-pi05/examples/libero/README.md ADDED
@@ -0,0 +1,71 @@
+ # LIBERO Benchmark
+
+ This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
+
+ Note: When updating `requirements.txt` in this directory, add the flag `--extra-index-url https://download.pytorch.org/whl/cu113` to the `uv pip compile` command recorded at the top of `requirements.txt`.
+
+ This example requires git submodules to be initialized. Don't forget to run:
+
+ ```bash
+ git submodule update --init --recursive
+ ```
+
+ ## With Docker (recommended)
+
+ ```bash
+ # Grant access to the X11 server:
+ sudo xhost +local:docker
+
+ # To run with the default checkpoint and task suite:
+ SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+
+ # To run with glx for Mujoco instead (use this if you have egl errors):
+ MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+ ```
+
+ You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
+ For example:
+
+ ```bash
+ # To load a custom checkpoint (located in the top-level openpi/ directory):
+ export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
+
+ # To run the libero_10 task suite:
+ export CLIENT_ARGS="--args.task-suite-name libero_10"
+ ```
+
+ ## Without Docker (not recommended)
+
+ Terminal window 1:
+
+ ```bash
+ # Create virtual environment
+ uv venv --python 3.8 examples/libero/.venv
+ source examples/libero/.venv/bin/activate
+ uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+ uv pip install -e packages/openpi-client
+ uv pip install -e third_party/libero
+ export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
+
+ # Run the simulation
+ python examples/libero/main.py
+
+ # To run with glx for Mujoco instead (use this if you have egl errors):
+ MUJOCO_GL=glx python examples/libero/main.py
+ ```
+
+ Terminal window 2:
+
+ ```bash
+ # Run the server
+ uv run scripts/serve_policy.py --env LIBERO
+ ```
+
+ ## Results
+
+ If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
+ checkpoint was trained in openpi with the `pi05_libero` config.
+
+ | Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
+ |-------|----------------|---------------|-------------|-----------|---------|
+ | π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |
capvector-pi05/examples/libero/main.py ADDED
@@ -0,0 +1,219 @@
+ import collections
+ import dataclasses
+ import logging
+ import math
+ import pathlib
+
+ import imageio
+ from libero.libero import benchmark
+ from libero.libero import get_libero_path
+ from libero.libero.envs import OffScreenRenderEnv
+ import numpy as np
+ from openpi_client import image_tools
+ from openpi_client import websocket_client_policy as _websocket_client_policy
+ import tqdm
+ import tyro
+
+ LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
+ LIBERO_ENV_RESOLUTION = 256  # resolution used to render training data
+
+
+ @dataclasses.dataclass
+ class Args:
+     #################################################################################################################
+     # Model server parameters
+     #################################################################################################################
+     host: str = "0.0.0.0"
+     port: int = 8000
+     resize_size: int = 224
+     replan_steps: int = 5
+
+     #################################################################################################################
+     # LIBERO environment-specific parameters
+     #################################################################################################################
+     task_suite_name: str = (
+         "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
+     )
+     num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
+     num_trials_per_task: int = 50  # Number of rollouts per task
+
+     #################################################################################################################
+     # Utils
+     #################################################################################################################
+     video_out_path: str = "data/libero/videos"  # Path to save videos
+
+     seed: int = 7  # Random Seed (for reproducibility)
+
+
+ def eval_libero(args: Args) -> None:
+     # Set random seed
+     np.random.seed(args.seed)
+
+     # Initialize LIBERO task suite
+     benchmark_dict = benchmark.get_benchmark_dict()
+     task_suite = benchmark_dict[args.task_suite_name]()
+     num_tasks_in_suite = task_suite.n_tasks
+     logging.info(f"Task suite: {args.task_suite_name}")
+
+     pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
+
+     if args.task_suite_name == "libero_spatial":
+         max_steps = 220  # longest training demo has 193 steps
+     elif args.task_suite_name == "libero_object":
+         max_steps = 280  # longest training demo has 254 steps
+     elif args.task_suite_name == "libero_goal":
+         max_steps = 300  # longest training demo has 270 steps
+     elif args.task_suite_name == "libero_10":
+         max_steps = 520  # longest training demo has 505 steps
+     elif args.task_suite_name == "libero_90":
+         max_steps = 400  # longest training demo has 373 steps
+     else:
+         raise ValueError(f"Unknown task suite: {args.task_suite_name}")
+
+     client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
+
+     # Start evaluation
+     total_episodes, total_successes = 0, 0
+     for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
+         # Get task
+         task = task_suite.get_task(task_id)
+
+         # Get default LIBERO initial states
+         initial_states = task_suite.get_task_init_states(task_id)
+
+         # Initialize LIBERO environment and task description
+         env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
+
+         # Start episodes
+         task_episodes, task_successes = 0, 0
+         for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
+             logging.info(f"\nTask: {task_description}")
+
+             # Reset environment
+             env.reset()
+             action_plan = collections.deque()
+
+             # Set initial states
+             obs = env.set_init_state(initial_states[episode_idx])
+
+             # Setup
+             t = 0
+             replay_images = []
+
+             logging.info(f"Starting episode {task_episodes+1}...")
+             while t < max_steps + args.num_steps_wait:
+                 try:
+                     # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
+                     # and we need to wait for them to fall
+                     if t < args.num_steps_wait:
+                         obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
+                         t += 1
+                         continue
+
+                     # Get preprocessed image
+                     # IMPORTANT: rotate 180 degrees to match train preprocessing
+                     img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
+                     wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
+                     img = image_tools.convert_to_uint8(
+                         image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
+                     )
+                     wrist_img = image_tools.convert_to_uint8(
+                         image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
+                     )
+
+                     # Save preprocessed image for replay video
+                     replay_images.append(img)
+
+                     if not action_plan:
+                         # Finished executing previous action chunk -- compute new chunk
+                         # Prepare observations dict
+                         element = {
+                             "observation/image": img,
+                             "observation/wrist_image": wrist_img,
+                             "observation/state": np.concatenate(
+                                 (
+                                     obs["robot0_eef_pos"],
+                                     _quat2axisangle(obs["robot0_eef_quat"]),
+                                     obs["robot0_gripper_qpos"],
+                                 )
+                             ),
+                             "prompt": str(task_description),
+                         }
+
+                         # Query model to get action
+                         action_chunk = client.infer(element)["actions"]
+                         assert (
+                             len(action_chunk) >= args.replan_steps
+                         ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
+                         action_plan.extend(action_chunk[: args.replan_steps])
+
+                     action = action_plan.popleft()
+
+                     # Execute action in environment
+                     obs, reward, done, info = env.step(action.tolist())
+                     if done:
+                         task_successes += 1
+                         total_successes += 1
+                         break
+                     t += 1
+
+                 except Exception as e:
+                     logging.error(f"Caught exception: {e}")
+                     break
+
+             task_episodes += 1
+             total_episodes += 1
+
+             # Save a replay video of the episode
+             suffix = "success" if done else "failure"
+             task_segment = task_description.replace(" ", "_")
+             imageio.mimwrite(
+                 pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
+                 [np.asarray(x) for x in replay_images],
+                 fps=10,
+             )
+
+             # Log current results
+             logging.info(f"Success: {done}")
+             logging.info(f"# episodes completed so far: {total_episodes}")
+             logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
+
+         # Log final results
+         logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
+         logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
+
+     logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
+     logging.info(f"Total episodes: {total_episodes}")
+
+
+ def _get_libero_env(task, resolution, seed):
+     """Initializes and returns the LIBERO environment, along with the task description."""
+     task_description = task.language
+     task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
+     env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
+     env = OffScreenRenderEnv(**env_args)
+     env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
+     return env, task_description
+
+
+ def _quat2axisangle(quat):
+     """
+     Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
+     """
+     # clip quaternion
+     if quat[3] > 1.0:
+         quat[3] = 1.0
+     elif quat[3] < -1.0:
+         quat[3] = -1.0
+
+     den = np.sqrt(1.0 - quat[3] * quat[3])
+     if math.isclose(den, 0.0):
+         # This is (close to) a zero degree rotation, immediately return
+         return np.zeros(3)
+
+     return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     tyro.cli(eval_libero)
capvector-pi05/examples/libero/requirements.in ADDED
@@ -0,0 +1,11 @@
+ imageio[ffmpeg]
+ numpy==1.22.4
+ tqdm
+ tyro
+ PyYaml
+ opencv-python==4.6.0.66
+ torch==1.11.0+cu113
+ torchvision==0.12.0+cu113
+ torchaudio==0.11.0+cu113
+ robosuite==1.4.1
+ matplotlib==3.5.3
capvector-pi05/examples/libero/requirements.txt ADDED
@@ -0,0 +1,136 @@
+ # This file was autogenerated by uv via the following command:
+ #   uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
+ absl-py==2.1.0
+     # via mujoco
+ certifi==2024.12.14
+     # via requests
+ charset-normalizer==3.4.0
+     # via requests
+ cycler==0.12.1
+     # via matplotlib
+ docstring-parser==0.16
+     # via tyro
+ etils==1.3.0
+     # via mujoco
+ eval-type-backport==0.2.0
+     # via tyro
+ evdev==1.7.1
+     # via pynput
+ fonttools==4.55.3
+     # via matplotlib
+ glfw==1.12.0
+     # via mujoco
+ idna==3.10
+     # via requests
+ imageio==2.35.1
+     # via -r examples/libero/requirements.in
+ imageio-ffmpeg==0.5.1
+     # via imageio
+ importlib-metadata==8.5.0
+     # via typeguard
+ importlib-resources==6.4.5
+     # via etils
+ kiwisolver==1.4.7
+     # via matplotlib
+ llvmlite==0.36.0
+     # via numba
+ markdown-it-py==3.0.0
+     # via rich
+ matplotlib==3.5.3
+     # via -r examples/libero/requirements.in
+ mdurl==0.1.2
+     # via markdown-it-py
+ mujoco==3.2.3
+     # via robosuite
+ numba==0.53.1
+     # via robosuite
+ numpy==1.22.4
+     # via
+     #   -r examples/libero/requirements.in
+     #   imageio
+     #   matplotlib
+     #   mujoco
+     #   numba
+     #   opencv-python
+     #   robosuite
+     #   scipy
+     #   torchvision
+ opencv-python==4.6.0.66
+     # via
+     #   -r examples/libero/requirements.in
+     #   robosuite
+ packaging==24.2
+     # via matplotlib
+ pillow==10.4.0
+     # via
+     #   imageio
+     #   matplotlib
+     #   robosuite
+     #   torchvision
+ psutil==6.1.0
+     # via imageio
+ pygments==2.18.0
+     # via rich
+ pynput==1.7.7
+     # via robosuite
+ pyopengl==3.1.7
+     # via mujoco
+ pyparsing==3.1.4
+     # via matplotlib
+ python-dateutil==2.9.0.post0
+     # via matplotlib
+ python-xlib==0.33
+     # via pynput
+ pyyaml==6.0.2
+     # via -r examples/libero/requirements.in
+ requests==2.32.3
+     # via torchvision
+ rich==13.9.4
+     # via tyro
+ robosuite==1.4.1
+     # via -r examples/libero/requirements.in
+ scipy==1.10.1
+     # via robosuite
+ setuptools==75.3.0
+     # via
+     #   imageio-ffmpeg
+     #   numba
+ shtab==1.7.1
+     # via tyro
+ six==1.17.0
+     # via
+     #   pynput
+     #   python-dateutil
+     #   python-xlib
+ termcolor==2.4.0
+     # via robosuite
+ torch==1.11.0+cu113
+     # via
+     #   -r examples/libero/requirements.in
+     #   torchaudio
+     #   torchvision
+ torchaudio==0.11.0+cu113
+     # via -r examples/libero/requirements.in
+ torchvision==0.12.0+cu113
+     # via -r examples/libero/requirements.in
+ tqdm==4.67.1
+     # via -r examples/libero/requirements.in
+ typeguard==4.4.0
+     # via tyro
+ typing-extensions==4.12.2
+     # via
+     #   etils
+     #   rich
+     #   torch
+     #   torchvision
+     #   typeguard
+     #   tyro
+ tyro==0.9.2
+     # via -r examples/libero/requirements.in
+ urllib3==2.2.3
+     # via requests
+ zipp==3.20.2
+     # via
+     #   etils
+     #   importlib-metadata
+     #   importlib-resources
capvector-pi05/examples/simple_client/Dockerfile ADDED
@@ -0,0 +1,32 @@
+ # Dockerfile for the simple client.
+
+ # Build the container:
+ # docker build . -t simple_client -f examples/simple_client/Dockerfile
+
+ # Run the container:
+ # docker run --rm -it --network=host -v .:/app simple_client /bin/bash
+
+ FROM python:3.7-slim
+ COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+ WORKDIR /app
+
+ # Copy from the cache instead of linking since it's a mounted volume
+ ENV UV_LINK_MODE=copy
+
+ # Write the virtual environment outside of the project directory so it doesn't
+ # leak out of the container when we mount the application code.
+ ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+ # Copy the requirements files so we can install dependencies.
+ # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+ # This strategy is best for development-style usage.
+ COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+ # Install python dependencies.
+ RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+ RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+ ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
+
+ CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
capvector-pi05/examples/simple_client/README.md ADDED
@@ -0,0 +1,30 @@
+ # Simple Client
+
+ A minimal client that sends observations to the server and prints the inference rate.
+
+ You can specify the runtime environment with the `--env` flag. You can see the available options by running:
+
+ ```bash
+ uv run examples/simple_client/main.py --help
+ ```
+
+ ## With Docker
+
+ ```bash
+ export SERVER_ARGS="--env ALOHA_SIM"
+ docker compose -f examples/simple_client/compose.yml up --build
+ ```
+
+ ## Without Docker
+
+ Terminal window 1:
+
+ ```bash
+ uv run examples/simple_client/main.py --env DROID
+ ```
+
+ Terminal window 2:
+
+ ```bash
+ uv run scripts/serve_policy.py --env DROID
+ ```
capvector-pi05/examples/simple_client/compose.yml ADDED
@@ -0,0 +1,42 @@
+ # Run with:
+ # docker compose -f examples/simple_client/compose.yml up --build
+ services:
+   runtime:
+     image: simple_client
+     depends_on:
+       - openpi_server
+     build:
+       context: ../..
+       dockerfile: examples/simple_client/Dockerfile
+     init: true
+     tty: true
+     network_mode: host
+     volumes:
+       - $PWD:/app
+     environment:
+       - SERVER_ARGS
+
+   openpi_server:
+     image: openpi_server
+     build:
+       context: ../..
+       dockerfile: scripts/docker/serve_policy.Dockerfile
+     init: true
+     tty: true
+     network_mode: host
+     volumes:
+       - $PWD:/app
+       - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+     environment:
+       - SERVER_ARGS
+       - OPENPI_DATA_HOME=/openpi_assets
+       - IS_DOCKER=true
+
+     # Comment out this block if not running on a machine with GPUs.
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - driver: nvidia
+               count: 1
+               capabilities: [gpu]
capvector-pi05/examples/simple_client/main.py ADDED
@@ -0,0 +1,187 @@
+ import dataclasses
+ import enum
+ import logging
+ import pathlib
+ import time
+
+ import numpy as np
+ from openpi_client import websocket_client_policy as _websocket_client_policy
+ import polars as pl
+ import rich.console
+ import rich.table
+ import tqdm
+ import tyro
+
+ logger = logging.getLogger(__name__)
+
+
+ class EnvMode(enum.Enum):
+     """Supported environments."""
+
+     ALOHA = "aloha"
+     ALOHA_SIM = "aloha_sim"
+     DROID = "droid"
+     LIBERO = "libero"
+
+
+ @dataclasses.dataclass
+ class Args:
+     """Command line arguments."""
+
+     # Host to connect to the server.
+     host: str = "0.0.0.0"
+     # Port to connect to the server. If None, the server will use the default port.
+     port: int | None = 8000
+     # API key to use for the server.
+     api_key: str | None = None
+     # Number of steps to run the policy for.
+     num_steps: int = 20
+     # Path to save the timings to a parquet file. (e.g., timing.parquet)
+     timing_file: pathlib.Path | None = None
+     # Environment to run the policy in.
+     env: EnvMode = EnvMode.ALOHA_SIM
+
+
+ class TimingRecorder:
+     """Records timing measurements for different keys."""
+
+     def __init__(self) -> None:
+         self._timings: dict[str, list[float]] = {}
+
+     def record(self, key: str, time_ms: float) -> None:
+         """Record a timing measurement for the given key."""
+         if key not in self._timings:
+             self._timings[key] = []
+         self._timings[key].append(time_ms)
+
+     def get_stats(self, key: str) -> dict[str, float]:
+         """Get statistics for the given key."""
+         times = self._timings[key]
+         return {
+             "mean": float(np.mean(times)),
+             "std": float(np.std(times)),
+             "p25": float(np.quantile(times, 0.25)),
+             "p50": float(np.quantile(times, 0.50)),
+             "p75": float(np.quantile(times, 0.75)),
+             "p90": float(np.quantile(times, 0.90)),
+             "p95": float(np.quantile(times, 0.95)),
+             "p99": float(np.quantile(times, 0.99)),
+         }
+
+     def print_all_stats(self) -> None:
+         """Print statistics for all keys in a concise format."""
+         table = rich.table.Table(
+             title="[bold blue]Timing Statistics[/bold blue]",
+             show_header=True,
+             header_style="bold white",
+             border_style="blue",
+             title_justify="center",
+         )
+
+         # Add metric column with custom styling
+         table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
+
+         # Add statistical columns with consistent styling
+         stat_columns = [
+             ("Mean", "yellow", "mean"),
+             ("Std", "yellow", "std"),
+             ("P25", "magenta", "p25"),
+             ("P50", "magenta", "p50"),
+             ("P75", "magenta", "p75"),
+             ("P90", "magenta", "p90"),
+             ("P95", "magenta", "p95"),
+             ("P99", "magenta", "p99"),
+         ]
+
+         for name, style, _ in stat_columns:
+             table.add_column(name, justify="right", style=style, no_wrap=True)
+
+         # Add rows for each metric with formatted values
+         for key in sorted(self._timings.keys()):
+             stats = self.get_stats(key)
+             values = [f"{stats[stat_key]:.1f}" for _, _, stat_key in stat_columns]
+             table.add_row(key, *values)
+
+         # Print with custom console settings
+         console = rich.console.Console(width=None, highlight=True)
+         console.print(table)
+
+     def write_parquet(self, path: pathlib.Path) -> None:
+         """Save the timings to a parquet file."""
+         logger.info(f"Writing timings to {path}")
+         frame = pl.DataFrame(self._timings)
+         path.parent.mkdir(parents=True, exist_ok=True)
+         frame.write_parquet(path)
+
+
+ def main(args: Args) -> None:
+     obs_fn = {
+         EnvMode.ALOHA: _random_observation_aloha,
+         EnvMode.ALOHA_SIM: _random_observation_aloha,
+         EnvMode.DROID: _random_observation_droid,
+         EnvMode.LIBERO: _random_observation_libero,
+     }[args.env]
+
+     policy = _websocket_client_policy.WebsocketClientPolicy(
+         host=args.host,
+         port=args.port,
+         api_key=args.api_key,
+     )
+     logger.info(f"Server metadata: {policy.get_server_metadata()}")
+
+     # Send a few observations to make sure the model is loaded.
+     for _ in range(2):
+         policy.infer(obs_fn())
+
+     timing_recorder = TimingRecorder()
+
+     for _ in tqdm.trange(args.num_steps, desc="Running policy"):
+         inference_start = time.time()
+         action = policy.infer(obs_fn())
+         timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
+         for key, value in action.get("server_timing", {}).items():
+             timing_recorder.record(f"server_{key}", value)
+         for key, value in action.get("policy_timing", {}).items():
+             timing_recorder.record(f"policy_{key}", value)
+
+     timing_recorder.print_all_stats()
+
+     if args.timing_file is not None:
+         timing_recorder.write_parquet(args.timing_file)
+
+
+ def _random_observation_aloha() -> dict:
+     return {
+         "state": np.ones((14,)),
+         "images": {
+             "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+             "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+             "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+             "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+         },
+         "prompt": "do something",
+     }
+
+
+ def _random_observation_droid() -> dict:
+     return {
+         "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+         "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+         "observation/joint_position": np.random.rand(7),
+         "observation/gripper_position": np.random.rand(1),
+         "prompt": "do something",
+     }
+
+
+ def _random_observation_libero() -> dict:
+     return {
+         "observation/state": np.random.rand(8),
+         "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+         "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+         "prompt": "do something",
+     }
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     main(tyro.cli(Args))
capvector-pi05/examples/simple_client/requirements.in ADDED
@@ -0,0 +1,5 @@
+ numpy>=1.22.4,<2.0.0
+ rich
+ tqdm
+ tyro
+ polars
capvector-pi05/examples/simple_client/requirements.txt ADDED
@@ -0,0 +1,30 @@
+ # This file was autogenerated by uv via the following command:
+ #   uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
+ docstring-parser==0.16
+     # via tyro
+ markdown-it-py==3.0.0
+     # via rich
+ mdurl==0.1.2
+     # via markdown-it-py
+ numpy==1.26.4
+     # via -r examples/simple_client/requirements.in
+ polars==1.30.0
+     # via -r examples/simple_client/requirements.in
+ pygments==2.19.1
+     # via rich
+ rich==14.0.0
+     # via
+     #   -r examples/simple_client/requirements.in
+     #   tyro
+ shtab==1.7.2
+     # via tyro
+ tqdm==4.67.1
+     # via -r examples/simple_client/requirements.in
+ typeguard==4.4.2
+     # via tyro
+ typing-extensions==4.13.2
+     # via
+     #   typeguard
+     #   tyro
+ tyro==0.9.22
+     # via -r examples/simple_client/requirements.in
capvector-pi05/examples/ur5/README.md ADDED
@@ -0,0 +1,142 @@
+ # UR5 Example
+
+ Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
+
+ First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding file in `src/openpi/policies/libero_policy.py` for comments explaining each line.
+
+ ```python
+ @dataclasses.dataclass(frozen=True)
+ class UR5Inputs(transforms.DataTransformFn):
+     model_type: _model.ModelType = _model.ModelType.PI0
+
+     def __call__(self, data: dict) -> dict:
+         # First, concatenate the joints and gripper into the state vector.
+         state = np.concatenate([data["joints"], data["gripper"]])
+
+         # Parse images to uint8 (H,W,C) if needed, since LeRobot automatically
+         # stores them as float32 (C,H,W); this step is skipped for policy inference.
+         base_image = _parse_image(data["base_rgb"])
+         wrist_image = _parse_image(data["wrist_rgb"])
+
+         # Create inputs dict.
+         inputs = {
+             "state": state,
+             "image": {
+                 "base_0_rgb": base_image,
+                 "left_wrist_0_rgb": wrist_image,
+                 # Since there is no right wrist, replace with zeros.
+                 "right_wrist_0_rgb": np.zeros_like(base_image),
+             },
+             "image_mask": {
+                 "base_0_rgb": np.True_,
+                 "left_wrist_0_rgb": np.True_,
+                 # Since the "slot" for the right wrist is not used, this mask is set
+                 # to False (pi0-FAST expects all masks to be True instead).
+                 "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
+             },
+         }
+
+         if "actions" in data:
+             inputs["actions"] = data["actions"]
+
+         # Pass the prompt (aka language instruction) to the model.
+         if "prompt" in data:
+             inputs["prompt"] = data["prompt"]
+
+         return inputs
+
+
+ @dataclasses.dataclass(frozen=True)
+ class UR5Outputs(transforms.DataTransformFn):
+     def __call__(self, data: dict) -> dict:
+         # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
+         return {"actions": np.asarray(data["actions"][:, :7])}
+ ```
+
+ Next, we will define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from the LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+
+ ```python
+ @dataclasses.dataclass(frozen=True)
+ class LeRobotUR5DataConfig(DataConfigFactory):
+     @override
+     def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
+         # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
+         repack_transform = _transforms.Group(
+             inputs=[
+                 _transforms.RepackTransform(
+                     {
+                         "base_rgb": "image",
+                         "wrist_rgb": "wrist_image",
+                         "joints": "joints",
+                         "gripper": "gripper",
+                         "prompt": "prompt",
+                     }
+                 )
+             ]
+         )
+
+         # These transforms are the ones we wrote earlier.
+         data_transforms = _transforms.Group(
+             inputs=[UR5Inputs(model_type=model_config.model_type)],
+             outputs=[UR5Outputs()],
+         )
+
+         # Convert absolute actions to delta actions.
+         # By convention, we do not convert the gripper action (7th dimension).
+         delta_action_mask = _transforms.make_bool_mask(6, -1)
+         data_transforms = data_transforms.push(
+             inputs=[_transforms.DeltaActions(delta_action_mask)],
+             outputs=[_transforms.AbsoluteActions(delta_action_mask)],
+         )
+
+         # Model transforms include things like tokenizing the prompt and action targets.
+         # You do not need to change anything here for your own dataset.
+         model_transforms = ModelTransformFactory()(model_config)
+
+         # We return all data transforms for training and inference. No need to change anything here.
+         return dataclasses.replace(
+             self.create_base_config(assets_dirs),
+             repack_transforms=repack_transform,
+             data_transforms=data_transforms,
+             model_transforms=model_transforms,
+         )
+ ```
+
+ Finally, we define the `TrainConfig` for our UR5 dataset. Here, we define a config for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
+
+ ```python
+ TrainConfig(
+     name="pi0_ur5",
+     model=pi0.Pi0Config(),
+     data=LeRobotUR5DataConfig(
+         repo_id="your_username/ur5_dataset",
+         # This config lets us reload the UR5 normalization stats from the base model checkpoint.
+         # Reloading normalization stats can help transfer pre-trained models to new environments.
+         # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
+         assets=AssetsConfig(
+             assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
+             asset_id="ur5e",
+         ),
+         base_config=DataConfig(
+             # This flag determines whether we load the prompt (i.e. the task instruction) from the
+             # ``task`` field in the LeRobot dataset. The recommended setting is True.
+             prompt_from_task=True,
+         ),
+     ),
+     # Load the pi0 base model checkpoint.
+     weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
+     num_train_steps=30_000,
+ )
+ ```
capvector-pi05/packages/openpi-client/pyproject.toml ADDED
@@ -0,0 +1,23 @@
+ [project]
+ name = "openpi-client"
+ version = "0.1.0"
+ requires-python = ">=3.7"
+ dependencies = [
+     "dm-tree>=0.1.8",
+     "msgpack>=1.0.5",
+     "numpy>=1.22.4,<2.0.0",
+     "pillow>=9.0.0",
+     "tree>=0.2.4",
+     "typing-extensions>=4.0.0",  # used by action_chunk_broker, policy_agent, and websocket_client_policy
+     "websockets>=11.0",
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.uv]
+ dev-dependencies = ["pytest>=8.3.4"]
+
+ [tool.ruff]
+ line-length = 120
+ target-version = "py37"
capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.1.0"
capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py ADDED
@@ -0,0 +1,50 @@
+ from typing import Dict
+
+ import numpy as np
+ import tree
+ from typing_extensions import override
+
+ from openpi_client import base_policy as _base_policy
+
+
+ class ActionChunkBroker(_base_policy.BasePolicy):
+     """Wraps a policy to return action chunks one-at-a-time.
+
+     Assumes that the first dimension of all action fields is the chunk size.
+
+     A new inference call to the inner policy is only made when the current
+     list of chunks is exhausted.
+     """
+
+     def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
+         self._policy = policy
+         self._action_horizon = action_horizon
+         self._cur_step: int = 0
+
+         self._last_results: Dict[str, np.ndarray] | None = None
+
+     @override
+     def infer(self, obs: Dict) -> Dict:  # noqa: UP006
+         if self._last_results is None:
+             self._last_results = self._policy.infer(obs)
+             self._cur_step = 0
+
+         def slicer(x):
+             if isinstance(x, np.ndarray):
+                 return x[self._cur_step, ...]
+             return x
+
+         results = tree.map_structure(slicer, self._last_results)
+         self._cur_step += 1
+
+         if self._cur_step >= self._action_horizon:
+             self._last_results = None
+
+         return results
+
+     @override
+     def reset(self) -> None:
+         self._policy.reset()
+         self._last_results = None
+         self._cur_step = 0
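To make the broker's contract concrete, here is a minimal usage sketch. The host, port, and horizon below are illustrative assumptions, not values fixed by this commit:

```python
# Hypothetical wiring: one server round-trip per `action_horizon` infer() calls.
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy

remote = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
policy = action_chunk_broker.ActionChunkBroker(remote, action_horizon=5)

# Each infer() returns one step sliced from the cached chunk; a new chunk is
# only requested from the server once the previous one is exhausted.
# for observation in observations:
#     step = policy.infer(observation)
```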
capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py ADDED
@@ -0,0 +1,12 @@
+ import abc
+ from typing import Dict
+
+
+ class BasePolicy(abc.ABC):
+     @abc.abstractmethod
+     def infer(self, obs: Dict) -> Dict:
+         """Infer actions from observations."""
+
+     def reset(self) -> None:
+         """Reset the policy to its initial state."""
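For reference, a subclass only needs to implement `infer`; `reset` is optional. A toy sketch (the class and its fixed chunk shape are illustrative, not part of the package):

```python
from typing import Dict

import numpy as np

from openpi_client import base_policy as _base_policy


class ConstantPolicy(_base_policy.BasePolicy):
    """Toy policy that returns the same 5-step, 7-dim action chunk for any observation."""

    def infer(self, obs: Dict) -> Dict:
        return {"actions": np.zeros((5, 7))}
```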
capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py ADDED
@@ -0,0 +1,78 @@
+ from typing import Tuple, Union
+
+ import numpy as np
+ from PIL import Image
+
+
+ def convert_to_uint8(img: np.ndarray) -> np.ndarray:
+     """Converts an image to uint8 if it is a float image.
+
+     This is important for reducing the size of the image when sending it over the network.
+     """
+     if np.issubdtype(img.dtype, np.floating):
+         img = (255 * img).astype(np.uint8)
+     return img
+
+
+ def resize_with_pad(
+     images: np.ndarray, height: int, width: int, method=Image.BILINEAR, return_mask: bool = False
+ ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+     """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target
+     height and width without distortion by padding with zeros.
+
+     Args:
+         images: A batch of images in [..., height, width, channel] format.
+         height: The target height of the image.
+         width: The target width of the image.
+         method: The interpolation method to use. Default is bilinear.
+         return_mask: If True, also return a boolean mask that is True for real pixels and False for padding.
+
+     Returns:
+         The resized images in [..., height, width, channel], plus the padding mask in [..., height, width]
+         if `return_mask` is True.
+     """
+     # If the images are already the correct size, return them as is.
+     if images.shape[-3:-1] == (height, width):
+         if return_mask:
+             img_padding_mask = np.ones((*images.shape[:-3], height, width), dtype=bool)
+             return images, img_padding_mask
+         return images
+
+     original_shape = images.shape
+
+     images = images.reshape(-1, *original_shape[-3:])
+
+     resized_results = [_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images]
+     resized_images, img_padding_mask = zip(*resized_results)
+     resized_images = np.stack(resized_images)
+     img_padding_mask = np.stack(img_padding_mask)
+
+     if return_mask:
+         return (
+             resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:]),
+             img_padding_mask.reshape(*original_shape[:-3], *img_padding_mask.shape[-2:]),
+         )
+     return resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:])
+
+
+ def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Tuple[Image.Image, np.ndarray]:
+     """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
+     width without distortion by padding with zeros. Also returns the boolean padding mask.
+
+     Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
+     """
+     cur_width, cur_height = image.size
+     if cur_width == width and cur_height == height:
+         # No need to resize if the image is already the correct size.
+         return image, np.ones((height, width), dtype=bool)
+
+     ratio = max(cur_width / width, cur_height / height)
+     resized_height = int(cur_height / ratio)
+     resized_width = int(cur_width / ratio)
+     resized_image = image.resize((resized_width, resized_height), resample=method)
+
+     zero_image = Image.new(resized_image.mode, (width, height), 0)
+     pad_height = max(0, int((height - resized_height) / 2))
+     pad_width = max(0, int((width - resized_width) / 2))
+     zero_image.paste(resized_image, (pad_width, pad_height))
+     assert zero_image.size == (width, height)
+
+     img_padding_mask = np.zeros((height, width), dtype=bool)
+     img_padding_mask[pad_height : pad_height + resized_height, pad_width : pad_width + resized_width] = True
+
+     return zero_image, img_padding_mask
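A short sketch of the `return_mask` path defined above (the input shapes are illustrative):

```python
import numpy as np

from openpi_client import image_tools

# A batch of two 256x320 RGB frames, letterboxed into 224x224.
images = np.random.randint(256, size=(2, 256, 320, 3), dtype=np.uint8)
resized, mask = image_tools.resize_with_pad(images, 224, 224, return_mask=True)

assert resized.shape == (2, 224, 224, 3)
assert mask.shape == (2, 224, 224)  # True where real pixels landed, False on padding
```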
capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py ADDED
@@ -0,0 +1,37 @@
+ import numpy as np
+
+ import openpi_client.image_tools as image_tools
+
+
+ def test_resize_with_pad_shapes():
+     # Test case 1: Resize image to larger dimensions
+     images = np.zeros((2, 10, 10, 3), dtype=np.uint8)  # Input images of shape (batch_size, height, width, channels)
+     height = 20
+     width = 20
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (2, height, width, 3)
+     assert np.all(resized_images == 0)
+
+     # Test case 2: Resize image to smaller dimensions
+     images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
+     height = 15
+     width = 15
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (3, height, width, 3)
+     assert np.all(resized_images == 0)
+
+     # Test case 3: Resize image with the same dimensions
+     images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
+     height = 50
+     width = 50
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (1, height, width, 3)
+     assert np.all(resized_images == 0)
+
+     # Test case 4: Resize image with odd-numbered padding
+     images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
+     height = 60
+     width = 80
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (1, height, width, 3)
+     assert np.all(resized_images == 0)
capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py ADDED
@@ -0,0 +1,57 @@
+ """Adds NumPy array support to msgpack.
+
+ msgpack is good for (de)serializing data over a network for multiple reasons:
+ - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
+ - msgpack is widely used and has good cross-language support
+ - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
+   languages like Python and JavaScript
+ - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x
+   faster than pickle for serializing large arrays using the below strategy
+
+ The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly
+ is that it falls back to pickle for object arrays.
+ """
+
+ import functools
+
+ import msgpack
+ import numpy as np
+
+
+ def pack_array(obj):
+     if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
+         raise ValueError(f"Unsupported dtype: {obj.dtype}")
+
+     if isinstance(obj, np.ndarray):
+         return {
+             b"__ndarray__": True,
+             b"data": obj.tobytes(),
+             b"dtype": obj.dtype.str,
+             b"shape": obj.shape,
+         }
+
+     if isinstance(obj, np.generic):
+         return {
+             b"__npgeneric__": True,
+             b"data": obj.item(),
+             b"dtype": obj.dtype.str,
+         }
+
+     return obj
+
+
+ def unpack_array(obj):
+     if b"__ndarray__" in obj:
+         return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
+
+     if b"__npgeneric__" in obj:
+         return np.dtype(obj[b"dtype"]).type(obj[b"data"])
+
+     return obj
+
+
+ Packer = functools.partial(msgpack.Packer, default=pack_array)
+ packb = functools.partial(msgpack.packb, default=pack_array)
+
+ Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
+ unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
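The `packb`/`unpackb` round trip preserves array dtype and shape; a small sketch (the observation dict is illustrative):

```python
import numpy as np

from openpi_client import msgpack_numpy

obs = {"state": np.zeros(7, dtype=np.float32), "prompt": "pick up the cup"}
packed = msgpack_numpy.packb(obs)  # bytes, safe to send over a websocket
restored = msgpack_numpy.unpackb(packed)

assert restored["state"].dtype == np.float32
assert np.array_equal(obs["state"], restored["state"])
```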
capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py ADDED
@@ -0,0 +1,45 @@
+ import numpy as np
+ import pytest
+ import tree
+
+ from openpi_client import msgpack_numpy
+
+
+ def _check(expected, actual):
+     if isinstance(expected, np.ndarray):
+         assert expected.shape == actual.shape
+         assert expected.dtype == actual.dtype
+         assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
+     else:
+         assert expected == actual
+
+
+ @pytest.mark.parametrize(
+     "data",
+     [
+         1,  # int
+         1.0,  # float
+         "hello",  # string
+         np.bool_(True),  # boolean scalar
+         np.array([1, 2, 3])[0],  # int scalar
+         np.str_("asdf"),  # string scalar
+         [1, 2, 3],  # list
+         {"key": "value"},  # dict
+         {"key": [1, 2, 3]},  # nested dict
+         np.array(1.0),  # 0D array
+         np.array([1, 2, 3], dtype=np.int32),  # 1D integer array
+         np.array(["asdf", "qwer"]),  # string array
+         np.array([True, False]),  # boolean array
+         np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),  # 2D float array
+         np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16),  # 3D integer array
+         np.array([np.nan, np.inf, -np.inf]),  # special float values
+         {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}},  # nested dict with arrays
+         [np.array([1, 2]), np.array([3, 4])],  # list of arrays
+         np.zeros((3, 4, 5), dtype=np.float32),  # 3D zeros
+         np.ones((2, 3), dtype=np.float64),  # 2D ones with double precision
+     ],
+ )
+ def test_pack_unpack(data):
+     packed = msgpack_numpy.packb(data)
+     unpacked = msgpack_numpy.unpackb(packed)
+     tree.map_structure(_check, data, unpacked)
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py ADDED
@@ -0,0 +1,17 @@
+ import abc
+
+
+ class Agent(abc.ABC):
+     """An Agent is the thing with agency, i.e. the entity that makes decisions.
+
+     Agents receive observations about the state of the world, and return actions
+     to take in response.
+     """
+
+     @abc.abstractmethod
+     def get_action(self, observation: dict) -> dict:
+         """Query the agent for the next action."""
+
+     @abc.abstractmethod
+     def reset(self) -> None:
+         """Reset the agent to its initial state."""
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py ADDED
@@ -0,0 +1,18 @@
+ from typing_extensions import override
+
+ from openpi_client import base_policy as _base_policy
+ from openpi_client.runtime import agent as _agent
+
+
+ class PolicyAgent(_agent.Agent):
+     """An agent that uses a policy to determine actions."""
+
+     def __init__(self, policy: _base_policy.BasePolicy) -> None:
+         self._policy = policy
+
+     @override
+     def get_action(self, observation: dict) -> dict:
+         return self._policy.infer(observation)
+
+     @override
+     def reset(self) -> None:
+         self._policy.reset()
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py ADDED
@@ -0,0 +1,32 @@
+ import abc
+
+
+ class Environment(abc.ABC):
+     """An Environment represents the robot and the environment it inhabits.
+
+     The primary contract of environments is that they can be queried for observations
+     about their state, and have actions applied to them to change that state.
+     """
+
+     @abc.abstractmethod
+     def reset(self) -> None:
+         """Reset the environment to its initial state.
+
+         This will be called once before starting each episode.
+         """
+
+     @abc.abstractmethod
+     def is_episode_complete(self) -> bool:
+         """Allow the environment to signal that the episode is complete.
+
+         This will be called after each step. It should return `True` if the episode is
+         complete (either successfully or unsuccessfully), and `False` otherwise.
+         """
+
+     @abc.abstractmethod
+     def get_observation(self) -> dict:
+         """Query the environment for the current state."""
+
+     @abc.abstractmethod
+     def apply_action(self, action: dict) -> None:
+         """Take an action in the environment."""
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import logging
+ import threading
+ import time
+ 
+ from openpi_client.runtime import agent as _agent
+ from openpi_client.runtime import environment as _environment
+ from openpi_client.runtime import subscriber as _subscriber
+ 
+ 
+ class Runtime:
+     """The core module orchestrating interactions between key components of the system."""
+ 
+     def __init__(
+         self,
+         environment: _environment.Environment,
+         agent: _agent.Agent,
+         subscribers: list[_subscriber.Subscriber],
+         max_hz: float = 0,
+         num_episodes: int = 1,
+         max_episode_steps: int = 0,
+     ) -> None:
+         self._environment = environment
+         self._agent = agent
+         self._subscribers = subscribers
+         self._max_hz = max_hz
+         self._num_episodes = num_episodes
+         self._max_episode_steps = max_episode_steps
+ 
+         self._in_episode = False
+         self._episode_steps = 0
+ 
+     def run(self) -> None:
+         """Runs the configured number of episodes, then resets the environment one final time."""
+         for _ in range(self._num_episodes):
+             self._run_episode()
+ 
+         # Final reset; this is important for real environments to move the robot to its home position.
+         self._environment.reset()
+ 
+     def run_in_new_thread(self) -> threading.Thread:
+         """Runs the runtime loop in a new thread."""
+         thread = threading.Thread(target=self.run)
+         thread.start()
+         return thread
+ 
+     def mark_episode_complete(self) -> None:
+         """Marks the end of an episode."""
+         self._in_episode = False
+ 
+     def _run_episode(self) -> None:
+         """Runs a single episode."""
+         logging.info("Starting episode...")
+         self._environment.reset()
+         self._agent.reset()
+         for subscriber in self._subscribers:
+             subscriber.on_episode_start()
+ 
+         self._in_episode = True
+         self._episode_steps = 0
+         step_time = 1 / self._max_hz if self._max_hz > 0 else 0
+         last_step_time = time.time()
+ 
+         while self._in_episode:
+             self._step()
+             self._episode_steps += 1
+ 
+             # Sleep to maintain the desired frame rate.
+             now = time.time()
+             dt = now - last_step_time
+             if dt < step_time:
+                 time.sleep(step_time - dt)
+                 last_step_time = time.time()
+             else:
+                 last_step_time = now
+ 
+         logging.info("Episode completed.")
+         for subscriber in self._subscribers:
+             subscriber.on_episode_end()
+ 
+     def _step(self) -> None:
+         """A single step of the runtime loop."""
+         observation = self._environment.get_observation()
+         action = self._agent.get_action(observation)
+         self._environment.apply_action(action)
+ 
+         for subscriber in self._subscribers:
+             subscriber.on_step(observation, action)
+ 
+         if self._environment.is_episode_complete() or (
+             self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
+         ):
+             self.mark_episode_complete()
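Wiring the pieces together is straightforward. A sketch, under the assumption that `PolicyAgent` (see `runtime/agents/policy_agent.py`) wraps a `BasePolicy` as its only required argument:

    from openpi_client import websocket_client_policy
    from openpi_client.runtime import runtime as _runtime
    from openpi_client.runtime.agents import policy_agent as _policy_agent

    policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

    runtime = _runtime.Runtime(
        environment=DummyEnvironment(),  # any Environment subclass, e.g. the sketch above
        agent=_policy_agent.PolicyAgent(policy=policy),
        subscribers=[],
        max_hz=20,  # cap the control loop at 20 Hz; 0 disables rate limiting
        num_episodes=1,
        max_episode_steps=100,  # 0 means no per-episode step limit
    )
    runtime.run()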
capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py ADDED
@@ -0,0 +1,20 @@
+ import abc
+ 
+ 
+ class Subscriber(abc.ABC):
+     """Subscribes to events in the runtime.
+ 
+     Subscribers can be used to save data, visualize, etc.
+     """
+ 
+     @abc.abstractmethod
+     def on_episode_start(self) -> None:
+         """Called when an episode starts."""
+ 
+     @abc.abstractmethod
+     def on_step(self, observation: dict, action: dict) -> None:
+         """Called after each step with that step's observation and action."""
+ 
+     @abc.abstractmethod
+     def on_episode_end(self) -> None:
+         """Called when an episode ends."""
capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py ADDED
@@ -0,0 +1,55 @@
+ import logging
+ import time
+ from typing import Dict, Optional, Tuple
+ 
+ from typing_extensions import override
+ import websockets.sync.client
+ 
+ from openpi_client import base_policy as _base_policy
+ from openpi_client import msgpack_numpy
+ 
+ 
+ class WebsocketClientPolicy(_base_policy.BasePolicy):
+     """Implements the Policy interface by communicating with a server over websocket.
+ 
+     See WebsocketPolicyServer for a corresponding server implementation.
+     """
+ 
+     def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
+         self._uri = f"ws://{host}"
+         if port is not None:
+             self._uri += f":{port}"
+         self._packer = msgpack_numpy.Packer()
+         self._api_key = api_key
+         self._ws, self._server_metadata = self._wait_for_server()
+ 
+     def get_server_metadata(self) -> Dict:
+         return self._server_metadata
+ 
+     def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
+         logging.info(f"Waiting for server at {self._uri}...")
+         while True:
+             try:
+                 headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
+                 conn = websockets.sync.client.connect(
+                     self._uri, compression=None, max_size=None, additional_headers=headers
+                 )
+                 metadata = msgpack_numpy.unpackb(conn.recv())
+                 return conn, metadata
+             except ConnectionRefusedError:
+                 logging.info("Still waiting for server...")
+                 time.sleep(5)
+ 
+     @override
+     def infer(self, obs: Dict) -> Dict:  # noqa: UP006
+         data = self._packer.pack(obs)
+         self._ws.send(data)
+         response = self._ws.recv()
+         if isinstance(response, str):
+             # We're expecting bytes; if the server sends a string, it's an error.
+             raise RuntimeError(f"Error in inference server:\n{response}")
+         return msgpack_numpy.unpackb(response)
+ 
+     @override
+     def reset(self) -> None:
+         pass
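Typical client-side usage looks like the sketch below. The observation keys and shapes are placeholders that depend on the policy being served; `image_tools.resize_with_pad` is the resizing helper shipped in this package:

    import numpy as np

    from openpi_client import image_tools, websocket_client_policy

    # Blocks (retrying every 5 s) until the policy server is reachable.
    client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
    print(client.get_server_metadata())

    obs = {
        "observation/image": image_tools.resize_with_pad(np.zeros((480, 640, 3), dtype=np.uint8), 224, 224),
        "observation/state": np.zeros(8, dtype=np.float32),
        "prompt": "pick up the cube",
    }
    result = client.infer(obs)
    print(result["actions"].shape)  # action chunk; shape depends on the model config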
capvector-pi05/scripts/__init__.py ADDED
File without changes
capvector-pi05/scripts/compute_norm_stats.py ADDED
@@ -0,0 +1,117 @@
+ """Compute normalization statistics for a config.
+ 
+ This script is used to compute the normalization statistics for a given config. It
+ will compute the mean and standard deviation of the data in the dataset and save it
+ to the config assets directory.
+ """
+ 
+ import numpy as np
+ import tqdm
+ import tyro
+ 
+ import openpi.models.model as _model
+ import openpi.shared.normalize as normalize
+ import openpi.training.config as _config
+ import openpi.training.data_loader as _data_loader
+ import openpi.transforms as transforms
+ 
+ 
+ class RemoveStrings(transforms.DataTransformFn):
+     def __call__(self, x: dict) -> dict:
+         return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
+ 
+ 
+ def create_torch_dataloader(
+     data_config: _config.DataConfig,
+     action_horizon: int,
+     batch_size: int,
+     model_config: _model.BaseModelConfig,
+     num_workers: int,
+     max_frames: int | None = None,
+ ) -> tuple[_data_loader.Dataset, int]:
+     if data_config.repo_id is None:
+         raise ValueError("Data config must have a repo_id")
+     dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
+     dataset = _data_loader.TransformedDataset(
+         dataset,
+         [
+             *data_config.repack_transforms.inputs,
+             *data_config.data_transforms.inputs,
+             # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
+             RemoveStrings(),
+         ],
+     )
+     if max_frames is not None and max_frames < len(dataset):
+         num_batches = max_frames // batch_size
+         shuffle = True
+     else:
+         num_batches = len(dataset) // batch_size
+         shuffle = False
+     data_loader = _data_loader.TorchDataLoader(
+         dataset,
+         local_batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=shuffle,
+         num_batches=num_batches,
+     )
+     return data_loader, num_batches
+ 
+ 
+ def create_rlds_dataloader(
+     data_config: _config.DataConfig,
+     action_horizon: int,
+     batch_size: int,
+     max_frames: int | None = None,
+ ) -> tuple[_data_loader.Dataset, int]:
+     dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)
+     dataset = _data_loader.IterableTransformedDataset(
+         dataset,
+         [
+             *data_config.repack_transforms.inputs,
+             *data_config.data_transforms.inputs,
+             # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
+             RemoveStrings(),
+         ],
+         is_batched=True,
+     )
+     if max_frames is not None and max_frames < len(dataset):
+         num_batches = max_frames // batch_size
+     else:
+         # NOTE: this length is currently hard-coded for DROID.
+         num_batches = len(dataset) // batch_size
+     data_loader = _data_loader.RLDSDataLoader(
+         dataset,
+         num_batches=num_batches,
+     )
+     return data_loader, num_batches
+ 
+ 
+ def main(config_name: str, max_frames: int | None = None):
+     config = _config.get_config(config_name)
+     data_config = config.data.create(config.assets_dirs, config.model)
+ 
+     if data_config.rlds_data_dir is not None:
+         data_loader, num_batches = create_rlds_dataloader(
+             data_config, config.model.action_horizon, config.batch_size, max_frames
+         )
+     else:
+         data_loader, num_batches = create_torch_dataloader(
+             data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames
+         )
+ 
+     keys = ["state", "actions"]
+     stats = {key: normalize.RunningStats() for key in keys}
+ 
+     for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
+         for key in keys:
+             stats[key].update(np.asarray(batch[key]))
+ 
+     norm_stats = {key: stat.get_statistics() for key, stat in stats.items()}
+ 
+     output_path = config.assets_dirs / data_config.repo_id
+     print(f"Writing stats to: {output_path}")
+     normalize.save(output_path, norm_stats)
+ 
+ 
+ if __name__ == "__main__":
+     tyro.cli(main)
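The internals of `normalize.RunningStats` are not shown here, but conceptually it is a one-pass accumulator over count, sum, and sum of squares. A minimal sketch of the idea (not the actual implementation):

    import numpy as np


    class SimpleRunningStats:
        """One-pass mean/std accumulator; a sketch of what RunningStats does conceptually."""

        def __init__(self):
            self.count = 0
            self.sum = None
            self.sum_sq = None

        def update(self, batch: np.ndarray) -> None:
            # Flatten everything except the trailing feature dimension.
            x = batch.reshape(-1, batch.shape[-1]).astype(np.float64)
            if self.sum is None:
                self.sum = x.sum(axis=0)
                self.sum_sq = (x**2).sum(axis=0)
            else:
                self.sum += x.sum(axis=0)
                self.sum_sq += (x**2).sum(axis=0)
            self.count += x.shape[0]

        def get_statistics(self) -> dict:
            mean = self.sum / self.count
            # std via E[x^2] - mean^2, clipped at 0 to guard against rounding error.
            std = np.sqrt(np.maximum(self.sum_sq / self.count - mean**2, 0.0))
            return {"mean": mean, "std": std}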
capvector-pi05/scripts/docker/compose.yml ADDED
@@ -0,0 +1,29 @@
+ # Run with:
+ # docker compose -f scripts/docker/compose.yml up --build
+ services:
+   openpi_server:
+     image: openpi_server
+     build:
+       context: ../..
+       dockerfile: scripts/docker/serve_policy.Dockerfile
+     init: true
+     tty: true
+     network_mode: host
+     # Mount the configured openpi data home at /openpi_assets inside the container.
+     # Populate AWS credentials inside the container.
+     volumes:
+       - $PWD:/app
+       - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+     environment:
+       - SERVER_ARGS
+       - OPENPI_DATA_HOME=/openpi_assets
+       - IS_DOCKER=true
+ 
+     # Comment out this block if not running on a machine with GPUs.
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - driver: nvidia
+               count: 1
+               capabilities: [gpu]
capvector-pi05/scripts/docker/install_docker_ubuntu22.sh ADDED
@@ -0,0 +1,37 @@
+ #!/bin/bash
+ 
+ # Add Docker's official GPG key:
+ sudo apt-get update
+ sudo apt-get install -y ca-certificates curl
+ sudo install -m 0755 -d /etc/apt/keyrings
+ sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
+ sudo chmod a+r /etc/apt/keyrings/docker.asc
+ 
+ # Add the repository to Apt sources:
+ echo \
+     "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
+     $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
+     sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
+ sudo apt-get update
+ 
+ sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
+ 
+ # Add the current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
+ # See https://docs.docker.com/engine/install/linux-postinstall/
+ username=$(whoami)
+ sudo usermod -aG docker "$username"
+ 
+ # Configure docker to start automatically on system boot.
+ sudo systemctl enable docker.service
+ sudo systemctl enable containerd.service
+ 
+ # https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
+ # Note: the file-existence test needs -f; a bare [ path ] is always true.
+ if [ -f ~/.docker/config.json ]; then
+     sed -i 's/credsStore/credStore/g' ~/.docker/config.json
+ fi
+ 
+ echo ""
+ echo "********************************************************************"
+ echo "**** Restart to allow Docker permission changes to take effect. ****"
+ echo "********************************************************************"
+ echo ""
capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ 
+ # Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
+ # NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+ 
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
+     curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
+     sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
+     sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ 
+ # NVIDIA's documentation omits 'sudo' in the following command, but it is required.
+ sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ 
+ sudo nvidia-ctk runtime configure --runtime=docker
+ sudo systemctl restart docker
capvector-pi05/scripts/docker/serve_policy.Dockerfile ADDED
@@ -0,0 +1,38 @@
+ # Dockerfile for serving a PI policy.
+ # Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
+ 
+ # Build the container:
+ # docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
+ 
+ # Run the container:
+ # docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
+ 
+ FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+ COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+ 
+ WORKDIR /app
+ 
+ # Needed because LeRobot uses git-lfs.
+ RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang
+ 
+ # Copy from the cache instead of linking since it's a mounted volume.
+ ENV UV_LINK_MODE=copy
+ 
+ # Write the virtual environment outside of the project directory so it doesn't
+ # leak out of the container when we mount the application code.
+ ENV UV_PROJECT_ENVIRONMENT=/.venv
+ 
+ # Install the project's dependencies using the lockfile and settings.
+ RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     --mount=type=bind,source=uv.lock,target=uv.lock \
+     --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
+     --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
+     --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
+     GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
+ 
+ # Copy transformers_replace files while preserving the directory structure.
+ COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
+ RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace
+ 
+ CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
capvector-pi05/scripts/serve_policy.py ADDED
@@ -0,0 +1,122 @@
+ import dataclasses
+ import enum
+ import logging
+ import socket
+ 
+ import tyro
+ 
+ from openpi.policies import policy as _policy
+ from openpi.policies import policy_config as _policy_config
+ from openpi.serving import websocket_policy_server
+ from openpi.training import config as _config
+ 
+ 
+ class EnvMode(enum.Enum):
+     """Supported environments."""
+ 
+     ALOHA = "aloha"
+     ALOHA_SIM = "aloha_sim"
+     DROID = "droid"
+     LIBERO = "libero"
+ 
+ 
+ @dataclasses.dataclass
+ class Checkpoint:
+     """Load a policy from a trained checkpoint."""
+ 
+     # Training config name (e.g., "pi0_aloha_sim").
+     config: str
+     # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
+     dir: str
+ 
+ 
+ @dataclasses.dataclass
+ class Default:
+     """Use the default policy for the given environment."""
+ 
+ 
+ @dataclasses.dataclass
+ class Args:
+     """Arguments for the serve_policy script."""
+ 
+     # Environment to serve the policy for. This is only used when serving default policies.
+     env: EnvMode = EnvMode.ALOHA_SIM
+ 
+     # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a
+     # default prompt.
+     default_prompt: str | None = None
+ 
+     # Port to serve the policy on.
+     port: int = 8000
+     # Record the policy's behavior for debugging.
+     record: bool = False
+ 
+     # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
+     policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
+ 
+ 
+ # Default checkpoints that should be used for each environment.
+ DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
+     EnvMode.ALOHA: Checkpoint(
+         config="pi05_aloha",
+         dir="gs://openpi-assets/checkpoints/pi05_base",
+     ),
+     EnvMode.ALOHA_SIM: Checkpoint(
+         config="pi0_aloha_sim",
+         dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
+     ),
+     EnvMode.DROID: Checkpoint(
+         config="pi05_droid",
+         dir="gs://openpi-assets/checkpoints/pi05_droid",
+     ),
+     EnvMode.LIBERO: Checkpoint(
+         config="pi05_libero",
+         dir="gs://openpi-assets/checkpoints/pi05_libero",
+     ),
+ }
+ 
+ 
+ def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
+     """Create a default policy for the given environment."""
+     if checkpoint := DEFAULT_CHECKPOINT.get(env):
+         return _policy_config.create_trained_policy(
+             _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
+         )
+     raise ValueError(f"Unsupported environment mode: {env}")
+ 
+ 
+ def create_policy(args: Args) -> _policy.Policy:
+     """Create a policy from the given arguments."""
+     match args.policy:
+         case Checkpoint():
+             return _policy_config.create_trained_policy(
+                 _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
+             )
+         case Default():
+             return create_default_policy(args.env, default_prompt=args.default_prompt)
+ 
+ 
+ def main(args: Args) -> None:
+     policy = create_policy(args)
+     policy_metadata = policy.metadata
+ 
+     # Record the policy's behavior.
+     if args.record:
+         policy = _policy.PolicyRecorder(policy, "policy_records")
+ 
+     hostname = socket.gethostname()
+     local_ip = socket.gethostbyname(hostname)
+     logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
+ 
+     server = websocket_policy_server.WebsocketPolicyServer(
+         policy=policy,
+         host="0.0.0.0",
+         port=args.port,
+         metadata=policy_metadata,
+     )
+     server.serve_forever()
+ 
+ 
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO, force=True)
+     main(tyro.cli(Args))
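Besides the `tyro` CLI, the same entrypoint can be driven programmatically. A sketch (the config name and checkpoint directory are placeholders; any trained config/checkpoint pair works):

    # Run from the repository root with the scripts/ directory on sys.path.
    import serve_policy

    args = serve_policy.Args(
        port=8000,
        policy=serve_policy.Checkpoint(
            config="pi05_libero",
            dir="gs://openpi-assets/checkpoints/pi05_libero",
        ),
    )
    serve_policy.main(args)  # serves until interrupted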
capvector-pi05/scripts/train.py ADDED
@@ -0,0 +1,280 @@
+ import dataclasses
+ import functools
+ import logging
+ import platform
+ from typing import Any
+ 
+ import etils.epath as epath
+ import flax.nnx as nnx
+ from flax.training import common_utils
+ import flax.traverse_util as traverse_util
+ import jax
+ import jax.experimental
+ import jax.numpy as jnp
+ import numpy as np
+ import optax
+ import tqdm_loggable.auto as tqdm
+ import wandb
+ 
+ import openpi.models.model as _model
+ import openpi.shared.array_typing as at
+ import openpi.shared.nnx_utils as nnx_utils
+ import openpi.training.checkpoints as _checkpoints
+ import openpi.training.config as _config
+ import openpi.training.data_loader as _data_loader
+ import openpi.training.optimizer as _optimizer
+ import openpi.training.sharding as sharding
+ import openpi.training.utils as training_utils
+ import openpi.training.weight_loaders as _weight_loaders
+ 
+ 
+ def init_logging():
+     """Custom logging format for better readability."""
+     level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
+ 
+     class CustomFormatter(logging.Formatter):
+         def format(self, record):
+             record.levelname = level_mapping.get(record.levelname, record.levelname)
+             return super().format(record)
+ 
+     formatter = CustomFormatter(
+         fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
+         datefmt="%H:%M:%S",
+     )
+ 
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+     logger.handlers[0].setFormatter(formatter)
+ 
+ 
+ def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
+     if not enabled:
+         wandb.init(mode="disabled")
+         return
+ 
+     ckpt_dir = config.checkpoint_dir
+     if not ckpt_dir.exists():
+         raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
+     if resuming:
+         run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
+         wandb.init(id=run_id, resume="must", project=config.project_name)
+     else:
+         wandb.init(
+             name=config.exp_name,
+             config=dataclasses.asdict(config),
+             project=config.project_name,
+         )
+         (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
+ 
+     if log_code:
+         wandb.run.log_code(epath.Path(__file__).parent.parent)
+ 
+ 
+ def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
+     """Loads and validates the weights. Returns a loaded subset of the weights."""
+     loaded_params = loader.load(params_shape)
+     at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
+ 
+     # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
+     return traverse_util.unflatten_dict(
+         {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
+     )
+ 
+ 
+ @at.typecheck
+ def init_train_state(
+     config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
+ ) -> tuple[training_utils.TrainState, Any]:
+     tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
+ 
+     def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
+         rng, model_rng = jax.random.split(rng)
+         # Initialize the model (and its parameters).
+         model = config.model.create(model_rng)
+ 
+         # Merge the partial params into the model.
+         if partial_params is not None:
+             graphdef, state = nnx.split(model)
+             # This will produce an error if the partial params are not a subset of the state.
+             state.replace_by_pure_dict(partial_params)
+             model = nnx.merge(graphdef, state)
+ 
+         params = nnx.state(model)
+         # Convert frozen params to bfloat16.
+         params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
+ 
+         return training_utils.TrainState(
+             step=0,
+             params=params,
+             model_def=nnx.graphdef(model),
+             tx=tx,
+             opt_state=tx.init(params.filter(config.trainable_filter)),
+             ema_decay=config.ema_decay,
+             ema_params=None if config.ema_decay is None else params,
+         )
+ 
+     train_state_shape = jax.eval_shape(init, init_rng)
+     state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
+ 
+     if resume:
+         return train_state_shape, state_sharding
+ 
+     partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
+     replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+ 
+     # Initialize the train state and mix in the partial params.
+     train_state = jax.jit(
+         init,
+         donate_argnums=(1,),  # donate the partial params buffer.
+         in_shardings=replicated_sharding,
+         out_shardings=state_sharding,
+     )(init_rng, partial_params)
+ 
+     return train_state, state_sharding
+ 
+ 
+ @at.typecheck
+ def train_step(
+     config: _config.TrainConfig,
+     rng: at.KeyArrayLike,
+     state: training_utils.TrainState,
+     batch: tuple[_model.Observation, _model.Actions],
+ ) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
+     model = nnx.merge(state.model_def, state.params)
+     model.train()
+ 
+     @at.typecheck
+     def loss_fn(
+         model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
+     ):
+         chunked_loss = model.compute_loss(rng, observation, actions, train=True)
+         return jnp.mean(chunked_loss)
+ 
+     train_rng = jax.random.fold_in(rng, state.step)
+     observation, actions = batch
+ 
+     # Filter out frozen params.
+     diff_state = nnx.DiffState(0, config.trainable_filter)
+     loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
+ 
+     params = state.params.filter(config.trainable_filter)
+     updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
+     new_params = optax.apply_updates(params, updates)
+ 
+     # Update the model in place and return the new full state.
+     nnx.update(model, new_params)
+     new_params = nnx.state(model)
+ 
+     new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
+     if state.ema_decay is not None:
+         new_state = dataclasses.replace(
+             new_state,
+             ema_params=jax.tree.map(
+                 lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
+             ),
+         )
+ 
+     # Filter out params that aren't kernels.
+     kernel_params = nnx.state(
+         model,
+         nnx.All(
+             nnx.Param,
+             nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
+             lambda _, x: x.value.ndim > 1,
+         ),
+     )
+     info = {
+         "loss": loss,
+         "grad_norm": optax.global_norm(grads),
+         "param_norm": optax.global_norm(kernel_params),
+     }
+     return new_state, info
+ 
+ 
+ def main(config: _config.TrainConfig):
+     init_logging()
+     logging.info(f"Running on: {platform.node()}")
+ 
+     if config.batch_size % jax.device_count() != 0:
+         raise ValueError(
+             f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
+         )
+ 
+     jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
+ 
+     rng = jax.random.key(config.seed)
+     train_rng, init_rng = jax.random.split(rng)
+ 
+     mesh = sharding.make_mesh(config.fsdp_devices)
+     data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
+     replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+ 
+     checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
+         config.checkpoint_dir,
+         keep_period=config.keep_period,
+         overwrite=config.overwrite,
+         resume=config.resume,
+     )
+     init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
+ 
+     data_loader = _data_loader.create_data_loader(
+         config,
+         sharding=data_sharding,
+         shuffle=True,
+     )
+     data_iter = iter(data_loader)
+     batch = next(data_iter)
+     logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
+ 
+     # Log images from the first batch to sanity check.
+     images_to_log = [
+         wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
+         for i in range(min(5, len(next(iter(batch[0].images.values())))))
+     ]
+     wandb.log({"camera_views": images_to_log}, step=0)
+ 
+     train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
+     jax.block_until_ready(train_state)
+     logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
+ 
+     if resuming:
+         train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
+ 
+     ptrain_step = jax.jit(
+         functools.partial(train_step, config),
+         in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
+         out_shardings=(train_state_sharding, replicated_sharding),
+         donate_argnums=(1,),
+     )
+ 
+     start_step = int(train_state.step)
+     pbar = tqdm.tqdm(
+         range(start_step, config.num_train_steps),
+         initial=start_step,
+         total=config.num_train_steps,
+         dynamic_ncols=True,
+     )
+ 
+     infos = []
+     for step in pbar:
+         with sharding.set_mesh(mesh):
+             train_state, info = ptrain_step(train_rng, train_state, batch)
+         infos.append(info)
+         if step % config.log_interval == 0:
+             stacked_infos = common_utils.stack_forest(infos)
+             reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
+             info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
+             pbar.write(f"Step {step}: {info_str}")
+             wandb.log(reduced_info, step=step)
+             infos = []
+         batch = next(data_iter)
+ 
+         if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
+             _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
+ 
+     logging.info("Waiting for checkpoint manager to finish")
+     checkpoint_manager.wait_until_finished()
+ 
+ 
+ if __name__ == "__main__":
+     main(_config.cli())
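When `ema_decay` is set, `train_step` keeps an exponential moving average of the parameters via `jax.tree.map`; numerically the update is `ema = decay * ema + (1 - decay) * new`. A standalone numpy sketch of the same arithmetic:

    import numpy as np

    decay = 0.99
    ema = np.array([1.0, 1.0])  # running average of the parameters
    for step in range(3):
        new_params = np.array([0.0, 2.0])  # stand-in for the post-update parameters
        ema = decay * ema + (1 - decay) * new_params
        print(step, ema)
    # Each step moves the EMA 1% of the way toward the new parameters,
    # so it tracks training smoothly while damping step-to-step noise.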
capvector-pi05/scripts/train_align_pytorch.py ADDED
@@ -0,0 +1,658 @@
+ """
+ PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
+ This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
+ entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
+ pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
+ 
+ Usage
+     Single GPU:
+         python scripts/train_align_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
+     Example:
+         python scripts/train_align_pytorch.py debug --exp_name pytorch_ddp_test
+         python scripts/train_align_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
+     Multi-GPU (single node):
+         torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_align_pytorch.py <config_name> --exp_name <run_name>
+     Example:
+         torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_align_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
+         torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_align_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
+     Multi-Node Training:
+         torchrun \
+             --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
+             --master_addr=<master_ip> --master_port=<port> \
+             scripts/train_align_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
+ """
+ 
+ import dataclasses
+ import gc
+ import logging
+ import os
+ import platform
+ import shutil
+ import time
+ 
+ import jax
+ import numpy as np
+ import safetensors.torch
+ import torch
+ import torch.distributed as dist
+ import torch.nn.parallel
+ import tqdm
+ import wandb
+ 
+ import openpi.models.pi0_config
+ from openpi.models_pytorch import pi0_align_pytorch, pi0_pytorch, projectors
+ import openpi.shared.normalize as _normalize
+ import openpi.training.config as _config
+ import openpi.training.data_loader as _data
+ 
+ from vggt.models.vggt import VGGT
+ 
+ 
+ def init_logging():
+     level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
+ 
+     class CustomFormatter(logging.Formatter):
+         def format(self, record):
+             record.levelname = level_mapping.get(record.levelname, record.levelname)
+             return super().format(record)
+ 
+     formatter = CustomFormatter(
+         fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
+         datefmt="%H:%M:%S",
+     )
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+     if not logger.handlers:
+         ch = logging.StreamHandler()
+         ch.setFormatter(formatter)
+         logger.addHandler(ch)
+     else:
+         logger.handlers[0].setFormatter(formatter)
+ 
+ 
+ def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
+     """Initialize wandb logging."""
+     if not enabled:
+         wandb.init(mode="disabled")
+         return
+ 
+     ckpt_dir = config.checkpoint_dir
+     if not ckpt_dir.exists():
+         raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
+ 
+     if resuming:
+         run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
+         wandb.init(id=run_id, resume="must", project=config.project_name)
+     else:
+         wandb.init(
+             name=config.exp_name,
+             config=dataclasses.asdict(config),
+             project=config.project_name,
+         )
+         (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
+ 
+ 
+ def setup_ddp():
+     world_size = int(os.environ.get("WORLD_SIZE", "1"))
+     use_ddp = world_size > 1
+     if use_ddp and not torch.distributed.is_initialized():
+         backend = "nccl" if torch.cuda.is_available() else "gloo"
+         torch.distributed.init_process_group(backend=backend, init_method="env://")
+ 
+     # Set up debugging environment variables for DDP issues.
+     if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
+         os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
+ 
+     local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
+     device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+     if torch.cuda.is_available():
+         torch.cuda.set_device(device)
+     return use_ddp, local_rank, device
+ 
+ 
+ def cleanup_ddp():
+     if torch.distributed.is_initialized():
+         torch.distributed.barrier()
+         torch.distributed.destroy_process_group()
+ 
+ 
+ def set_seed(seed: int, local_rank: int):
+     torch.manual_seed(seed + local_rank)
+     np.random.seed(seed + local_rank)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed + local_rank)
+ 
+ 
+ def build_datasets(config: _config.TrainConfig):
+     # Use the unified data loader with the PyTorch framework.
+     data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
+     return data_loader, data_loader.data_config()
+ 
+ 
+ def get_model_state_dict(model):
+     """Get the state dict from a model, handling the DDP wrapper."""
+     return (
+         model.module.state_dict()
+         if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+         else model.state_dict()
+     )
+ 
+ 
+ def get_model_parameters(model):
+     """Get the parameters from a model, handling the DDP wrapper."""
+     return (
+         model.module.parameters()
+         if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+         else model.parameters()
+     )
+ 
+ 
+ def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
+     """Save a checkpoint with model state, optimizer state, and metadata."""
+     if not is_main:
+         return
+ 
+     # Only save if it's time to save or if it's the final step.
+     if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
+         # Create a temporary directory for atomic checkpoint saving.
+         final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
+         tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
+ 
+         # Remove any existing temp directory and create a new one.
+         if tmp_ckpt_dir.exists():
+             shutil.rmtree(tmp_ckpt_dir)
+         tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
+ 
+         # Save model state using safetensors (handles shared tensors).
+         model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
+         safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")
+ 
+         # Save optimizer state using the PyTorch format.
+         torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
+ 
+         # Save training metadata; the config is stored as a plain dict to avoid JAX/Flax pickling issues.
+         metadata = {
+             "global_step": global_step,
+             "config": dataclasses.asdict(config),
+             "timestamp": time.time(),
+         }
+         torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
+ 
+         # Save norm stats.
+         norm_stats = data_config.norm_stats
+         if norm_stats is not None and data_config.asset_id is not None:
+             _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)
+ 
+         # Atomically move the temp directory to the final location.
+         if final_ckpt_dir.exists():
+             shutil.rmtree(final_ckpt_dir)
+         tmp_ckpt_dir.rename(final_ckpt_dir)
+ 
+         logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
+ 
+         # Log the checkpoint to wandb.
+         if config.wandb_enabled:
+             wandb.log({"checkpoint_step": global_step}, step=global_step)
+ 
+ 
+ def load_checkpoint(model, optimizer, checkpoint_dir, device):
+     """Load the latest checkpoint and return the global step."""
+     checkpoint_steps = [
+         int(d.name)
+         for d in checkpoint_dir.iterdir()
+         if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
+     ]
+ 
+     if not checkpoint_steps:
+         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+ 
+     latest_step = max(checkpoint_steps)
+     ckpt_dir = checkpoint_dir / f"{latest_step}"
+ 
+     # Clear memory before loading checkpoints.
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     gc.collect()
+     log_memory_usage(device, latest_step, "before_loading_checkpoint")
+ 
+     try:
+         # Load model state with error handling.
+         logging.info("Loading model state...")
+         safetensors_path = ckpt_dir / "model.safetensors"
+ 
+         if safetensors_path.exists():
+             model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
+             safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
+             logging.info("Loaded model state from safetensors format")
+         else:
+             raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
+ 
+         torch.cuda.empty_cache()
+         gc.collect()
+         log_memory_usage(device, latest_step, "after_loading_model")
+ 
+         # Load optimizer state with error handling.
+         logging.info("Loading optimizer state...")
+         optimizer_path = ckpt_dir / "optimizer.pt"
+ 
+         if optimizer_path.exists():
+             optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
+             logging.info("Loaded optimizer state from pt format")
+         else:
+             raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
+ 
+         optimizer.load_state_dict(optimizer_state_dict)
+         del optimizer_state_dict
+         torch.cuda.empty_cache()
+         gc.collect()
+         log_memory_usage(device, latest_step, "after_loading_optimizer")
+ 
+         # Load metadata.
+         logging.info("Loading metadata...")
+         metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
+         global_step = metadata.get("global_step", latest_step)
+         del metadata
+         torch.cuda.empty_cache()
+         gc.collect()
+         log_memory_usage(device, latest_step, "after_loading_metadata")
+ 
+         logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
+         return global_step
+ 
+     except RuntimeError as e:
+         if "out of memory" in str(e):
+             # Clear memory and provide a detailed error message.
+             torch.cuda.empty_cache()
+             gc.collect()
+             logging.error(f"Out of memory error while loading checkpoint: {e!s}")
+             log_memory_usage(device, latest_step, "after_oom_error")
+             raise RuntimeError(
+                 "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
+             ) from e
+         raise
+ 
+ 
+ def get_latest_checkpoint_step(checkpoint_dir):
+     """Get the latest checkpoint step number from a checkpoint directory."""
+     checkpoint_steps = [
+         int(d.name)
+         for d in checkpoint_dir.iterdir()
+         if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
+     ]
+     return max(checkpoint_steps) if checkpoint_steps else None
+ 
+ 
+ def log_memory_usage(device, step, phase="unknown"):
+     """Log detailed memory usage information."""
+     if not torch.cuda.is_available():
+         return
+ 
+     memory_allocated = torch.cuda.memory_allocated(device) / 1e9
+     memory_reserved = torch.cuda.memory_reserved(device) / 1e9
+     # "Free" here means reserved-but-unallocated memory within the CUDA caching allocator.
+     memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9
+ 
+     # Get more detailed memory info.
+     memory_stats = torch.cuda.memory_stats(device)
+     max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
+     max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9
+ 
+     # Get DDP info if available.
+     ddp_info = ""
+     if dist.is_initialized():
+         ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
+ 
+     logging.info(
+         f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
+     )
+ 
+ 
+ def train_loop(config: _config.TrainConfig):
+     use_ddp, local_rank, device = setup_ddp()
+     is_main = (not use_ddp) or (dist.get_rank() == 0)
+     set_seed(config.seed, local_rank)
+ 
+     # Initialize the checkpoint directory and wandb.
+     resuming = False
+     if config.resume:
+         # Find the checkpoint directory based on the experiment name.
+         exp_checkpoint_dir = config.checkpoint_dir
+         if exp_checkpoint_dir.exists():
+             # Use validation to find the latest working checkpoint.
+             latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
+             if latest_step is not None:
+                 resuming = True
+                 logging.info(
+                     f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
+                 )
+             else:
+                 raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
+         else:
+             raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
+     elif config.overwrite and config.checkpoint_dir.exists():
+         shutil.rmtree(config.checkpoint_dir)
+         logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
+ 
+     # Create the checkpoint directory with the experiment name.
+     if not resuming:
+         # For new runs, create an experiment-specific checkpoint directory.
+         exp_checkpoint_dir = config.checkpoint_dir
+         exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+         logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
+     else:
+         # For resume, checkpoint_dir is already set to the experiment directory.
+         logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
+ 
+     # Initialize wandb (only on the main process).
+     if is_main:
+         init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
+ 
+     # Build the data loader using the unified data loader.
+     # Calculate the effective batch size per GPU for DDP:
+     # for N GPUs, each GPU should get batch_size/N samples, so the total across all GPUs is batch_size.
+     world_size = torch.distributed.get_world_size() if use_ddp else 1
+     effective_batch_size = config.batch_size // world_size
+     logging.info(
+         f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
+     )
+ 
+     # Pass the original batch size to the data loader - it will handle DDP splitting internally.
+     loader, data_config = build_datasets(config)
+ 
+     # Log sample images to wandb on the first batch.
+     if is_main and config.wandb_enabled and not resuming:
+         # Create a separate data loader for the sample batch to avoid consuming the main loader.
+         sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
+         sample_batch = next(iter(sample_data_loader))
+         # Convert observation and actions to torch tensors.
+         observation, actions = sample_batch
+         sample_batch = observation.to_dict()
+         sample_batch["actions"] = actions
+ 
+         # Create sample images for wandb.
+         images_to_log = []
+         # Get the batch size from the first image tensor.
+         batch_size = next(iter(sample_batch["image"].values())).shape[0]
+         for i in range(min(5, batch_size)):
+             # Concatenate all camera views horizontally for this batch item.
+             # Convert from NCHW to NHWC format for wandb.
+             img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
+             img_concatenated = img_concatenated.cpu().numpy()
+             images_to_log.append(wandb.Image(img_concatenated))
+ 
+         wandb.log({"camera_views": images_to_log}, step=0)
+ 
+         # Clear the sample batch from memory aggressively.
+         del sample_batch, observation, actions, images_to_log, img_concatenated
+         del sample_data_loader  # Also delete the sample data loader.
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         logging.info("Cleared sample batch and data loader from memory")
+ 
+     # Build the model.
+     if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
+         # Convert the dataclass to a Pi0Config if needed.
+         model_cfg = openpi.models.pi0_config.Pi0Config(
+             dtype=config.pytorch_training_precision,
+             action_dim=config.model.action_dim,
+             action_horizon=config.model.action_horizon,
+             max_token_len=config.model.max_token_len,
+             paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
+             action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
+             pi05=getattr(config.model, "pi05", False),
+         )
+     else:
+         model_cfg = config.model
+         # Update dtype to match pytorch_training_precision.
+         object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
+ 
+     model = pi0_align_pytorch.PI0Pytorch(model_cfg, config).to(device)
+     vggt_model = VGGT(
+         enable_camera=False,
+         enable_point=False,
+         enable_depth=False,
+         enable_track=False,
+         feature_only=True,
+     ).to(device)
+     align_projector = projectors.AlignProjector(
+         model.LLM_width,
+         config.vggt_dim,
+         config.use_vlm_norm,
+     ).to(device)
+ 
+     if hasattr(model, "gradient_checkpointing_enable"):
+         enable_gradient_checkpointing = True
+         model.gradient_checkpointing_enable()
+         logging.info("Enabled gradient checkpointing for memory optimization")
+     else:
+         enable_gradient_checkpointing = False
+         logging.info("Gradient checkpointing is not supported for this model")
+ 
+     # Log initial memory usage after model creation.
+     if is_main and torch.cuda.is_available():
+         log_memory_usage(device, 0, "after_model_creation")
+ 
+     # Enable memory optimizations for large-scale training.
+     if world_size >= 8:
+         torch.backends.cudnn.benchmark = True
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+         # Set the memory allocation configuration.
+         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+         logging.info("Enabled memory optimizations for 8+ GPU training")
+ 
+     if use_ddp:
+         model = torch.nn.parallel.DistributedDataParallel(
+             model,
+             device_ids=[device.index] if device.type == "cuda" else None,
+             find_unused_parameters=True,  # Some parameters may not receive gradients every step; costs extra memory.
+             gradient_as_bucket_view=True,  # Enable for memory efficiency.
+             static_graph=world_size >= 8,  # Enable for 8+ GPUs.
+         )
+         align_projector = torch.nn.parallel.DistributedDataParallel(
+             align_projector,
+             device_ids=[device.index] if device.type == "cuda" else None,
+             find_unused_parameters=True,  # Some parameters may not receive gradients every step; costs extra memory.
+             gradient_as_bucket_view=True,  # Enable for memory efficiency.
+             static_graph=world_size >= 8,  # Enable for 8+ GPUs.
+         )
+ 
+     # Load weights from weight_loader if specified (for fine-tuning).
+     if config.pytorch_weight_path is not None:
+         logging.info(f"Loading weights from: {config.pytorch_weight_path}")
+         model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
+         safetensors.torch.load_model(
+             (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model),
+             model_path,
+             strict=False,
+         )
+         logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
+     if config.vggt_weight_path is not None:
+         vggt_path = os.path.join(config.vggt_weight_path, "model.pt")
+         if not os.path.exists(vggt_path):
+             raise FileNotFoundError(f"VGGT weight file not found at {vggt_path}")
+         vggt_model.load_state_dict(torch.load(vggt_path), strict=False)
+         logging.info(f"Loaded VGGT weights from {config.vggt_weight_path}")
+ 
+     # Optimizer + learning rate schedule from config.
+     warmup_steps = config.lr_schedule.warmup_steps
+     peak_lr = config.lr_schedule.peak_lr
+     decay_steps = config.lr_schedule.decay_steps
+     end_lr = config.lr_schedule.decay_lr
+ 
+     # Create the optimizer with config parameters.
+     optim = torch.optim.AdamW(
+         list(model.parameters()) + list(align_projector.parameters()),
+         lr=peak_lr,
+         betas=(config.optimizer.b1, config.optimizer.b2),
+         eps=config.optimizer.eps,
+         weight_decay=config.optimizer.weight_decay,
+     )
+ 
+     # Load a checkpoint if resuming.
+     global_step = 0
+     if resuming:
+         global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
+         logging.info(f"Resumed training from step {global_step}")
+ 
+     def lr_schedule(step: int):
+         if step < warmup_steps:
+             # Match JAX behavior: start from peak_lr / (warmup_steps + 1).
+             init_lr = peak_lr / (warmup_steps + 1)
+             return init_lr + (peak_lr - init_lr) * step / warmup_steps
+         # Cosine decay.
+         progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
+         cos = 0.5 * (1 + np.cos(np.pi * progress))
+         return end_lr + (peak_lr - end_lr) * cos
+ 
+     model.train()
+     align_projector.train()
+     vggt_model.eval()
+     start_time = time.time()
+     infos = []  # Collect stats over the log interval.
+     if is_main:
+         logging.info(
+             f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
+         )
+         logging.info(
+             f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
+         )
+         logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
+         logging.info(
+             f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
+         )
+         logging.info(
+             f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
+         )
+         logging.info("EMA is not supported for PyTorch training")
+         logging.info(f"Training precision: {model_cfg.dtype}")
+ 
+     # Training loop - iterate until we reach num_train_steps.
+     pbar = (
+         tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
+         if is_main
+         else None
+     )
+ 
+     while global_step < config.num_train_steps:
+         # Set the epoch for distributed training.
+         if use_ddp and hasattr(loader, "set_epoch"):
+             loader.set_epoch(global_step // len(loader))
+ 
+         for observation, actions in loader:
+             # Check if we've reached the target number of steps.
+             if global_step >= config.num_train_steps:
+                 break
+ 
+             # The unified data loader returns an (observation, actions) tuple.
+             observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
+             actions = actions.to(torch.float32)  # noqa: PLW2901
+             actions = actions.to(device)  # noqa: PLW2901
+ 
+             # Update the LR.
+             for pg in optim.param_groups:
+                 pg["lr"] = lr_schedule(global_step)
+ 
+             # Forward pass.
+             action_losses, align_loss = model(observation, actions, vggt=vggt_model, align_proj=align_projector)
+             loss = action_losses + config.align_loss_coeff * align_loss
+ 
+             # Backward pass.
+             loss.backward()
+ 
+             # Log memory usage after the backward pass.
+             if global_step < 5 and is_main and torch.cuda.is_available():
+                 log_memory_usage(device, global_step, "after_backward")
+ 
+             # Gradient clipping (note: only the model's parameters are clipped here, not the projector's).
+             grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
+ 
+             # Optimizer step.
+             optim.step()
+             optim.zero_grad(set_to_none=True)
+ 
+             # Clear gradients more aggressively.
+             for param in model.parameters():
+                 if param.grad is not None:
+                     param.grad.detach_()
+                     param.grad = None
+ 
+             # Collect stats.
+             if is_main:
+                 infos.append(
+                     {
+                         "action_loss": action_losses.item(),
+                         "align_loss": align_loss.item(),
+                         "learning_rate": optim.param_groups[0]["lr"],
+                         "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
+                     }
+                 )
+ 
+             if is_main and (global_step % config.log_interval == 0):
+                 elapsed = time.time() - start_time
+ 
+                 # Average stats over the log interval.
+                 avg_loss = sum(info["action_loss"] for info in infos) / len(infos)
+                 avg_align_loss = sum(info["align_loss"] for info in infos) / len(infos)
+                 avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
+ 
+                 avg_grad_norm = None
+                 if any("grad_norm" in info for info in infos):
+                     vals = [
+                         info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
+                     ]
+                     if len(vals) > 0:
+                         avg_grad_norm = sum(vals) / len(vals)
+                 logging.info(
+                     f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
+                     if avg_grad_norm is not None
+                     else f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
+                 )
+ 
+                 # Log to wandb.
+                 if config.wandb_enabled and len(infos) > 0:
+                     log_payload = {
+                         "action_loss": avg_loss,
+                         "align_loss": avg_align_loss,
+                         "learning_rate": avg_lr,
+                         "step": global_step,
+                         "time_per_step": elapsed / config.log_interval,
+                     }
+                     if avg_grad_norm is not None:
+                         log_payload["grad_norm"] = avg_grad_norm
+                     wandb.log(log_payload, step=global_step)
+ 
+                 start_time = time.time()
+                 infos = []  # Reset stats collection.
+ 
+             global_step += 1
+             # Save a checkpoint using the mechanism above.
+             save_checkpoint(model, optim, global_step, config, is_main, data_config)
+ 
+             # Update the progress bar.
+             if pbar is not None:
+                 pbar.update(1)
+                 pbar.set_postfix(
+                     {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
+                 )
+ 
+     # Close the progress bar.
+     if pbar is not None:
+         pbar.close()
+ 
+     # Finish the wandb run.
+     if is_main and config.wandb_enabled:
+         wandb.finish()
+ 
+     cleanup_ddp()
+ 
+ 
+ def main():
+     init_logging()
+     config = _config.cli()
+     train_loop(config)
+ 
+ 
+ if __name__ == "__main__":
+     main()
capvector-pi05/scripts/train_pytorch.py ADDED
@@ -0,0 +1,632 @@
1
+ """
2
+ PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
3
+ This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
4
+ entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
5
+ pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
6
+
7
+ Usage
8
+ Single GPU:
9
+ python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
10
+ Example:
11
+ python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
12
+ python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
13
+ Multi-GPU (single node):
14
+ torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
15
+ Example:
16
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
17
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
18
+ Multi-Node Training:
19
+ torchrun \
20
+ --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
21
+ --master_addr=<master_ip> --master_port=<port> \
22
+ scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
23
+
24
+ """
25
+
26
+ import dataclasses
27
+ import gc
28
+ import logging
29
+ import os
30
+ import platform
31
+ import shutil
32
+ import time
33
+
34
+ import jax
35
+ import numpy as np
36
+ import safetensors.torch
37
+ import torch
38
+ import torch.distributed as dist
39
+ import torch.nn.parallel
40
+ import tqdm
41
+ import wandb
42
+
43
+ import openpi.models.pi0_config
44
+ import openpi.models_pytorch.pi0_pytorch
45
+ import openpi.shared.normalize as _normalize
46
+ import openpi.training.config as _config
47
+ import openpi.training.data_loader as _data
48
+
49
+
50
+ def init_logging():
51
+ level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
52
+
53
+ class CustomFormatter(logging.Formatter):
54
+ def format(self, record):
55
+ record.levelname = level_mapping.get(record.levelname, record.levelname)
56
+ return super().format(record)
57
+
58
+ formatter = CustomFormatter(
59
+ fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
60
+ datefmt="%H:%M:%S",
61
+ )
62
+ logger = logging.getLogger()
63
+ logger.setLevel(logging.INFO)
64
+ if not logger.handlers:
65
+ ch = logging.StreamHandler()
66
+ ch.setFormatter(formatter)
67
+ logger.addHandler(ch)
68
+ else:
69
+ logger.handlers[0].setFormatter(formatter)
70
+
71
+
72
+ def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
73
+ """Initialize wandb logging."""
74
+ if not enabled:
75
+ wandb.init(mode="disabled")
76
+ return
77
+
78
+ ckpt_dir = config.checkpoint_dir
79
+ if not ckpt_dir.exists():
80
+ raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
81
+
82
+ if resuming:
83
+ run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
84
+ wandb.init(id=run_id, resume="must", project=config.project_name)
85
+ else:
86
+ wandb.init(
87
+ name=config.exp_name,
88
+ config=dataclasses.asdict(config),
89
+ project=config.project_name,
90
+ )
91
+ (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
92
+
93
+
94
+ def setup_ddp():
95
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
96
+ use_ddp = world_size > 1
97
+ if use_ddp and not torch.distributed.is_initialized():
98
+ backend = "nccl" if torch.cuda.is_available() else "gloo"
99
+ torch.distributed.init_process_group(backend=backend, init_method="env://")
100
+
101
+ # Set up debugging environment variables for DDP issues
102
+ if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
103
+ os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
104
+
105
+ local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
106
+ device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
107
+ if torch.cuda.is_available():
108
+ torch.cuda.set_device(device)
109
+ return use_ddp, local_rank, device
110
+
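+ # With init_method="env://", torch.distributed reads MASTER_ADDR, MASTER_PORT,
+ # WORLD_SIZE and RANK from the environment; torchrun sets these (plus LOCAL_RANK)
+ # automatically for every worker.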
111
+
112
+ def cleanup_ddp():
113
+ if torch.distributed.is_initialized():
114
+ torch.distributed.barrier()
115
+ torch.distributed.destroy_process_group()
116
+
117
+
118
+ def set_seed(seed: int, local_rank: int):
119
+ torch.manual_seed(seed + local_rank)
120
+ np.random.seed(seed + local_rank)
121
+ if torch.cuda.is_available():
122
+ torch.cuda.manual_seed_all(seed + local_rank)
123
+
124
+
125
+ def build_datasets(config: _config.TrainConfig):
126
+ # Use the unified data loader with PyTorch framework
127
+ data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
128
+ return data_loader, data_loader.data_config()
129
+
130
+
131
+ def get_model_state_dict(model):
132
+ """Get state dict from model, handling DDP wrapper."""
133
+ return (
134
+ model.module.state_dict()
135
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
136
+ else model.state_dict()
137
+ )
138
+
139
+
140
+ def get_model_parameters(model):
141
+ """Get parameters from model, handling DDP wrapper."""
142
+ return (
143
+ model.module.parameters()
144
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
145
+ else model.parameters()
146
+ )
147
+
148
+
149
+ def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
150
+ """Save a checkpoint with model state, optimizer state, and metadata."""
151
+ if not is_main:
152
+ return
153
+
154
+ # Only save if it's time to save or if it's the final step
155
+ if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
156
+ # Create temporary directory for atomic checkpoint saving
157
+ final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
158
+ tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
159
+
160
+ # Remove any existing temp directory and create new one
161
+ if tmp_ckpt_dir.exists():
162
+ shutil.rmtree(tmp_ckpt_dir)
163
+ tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
164
+
165
+ # Save model state using safetensors (handle shared tensors)
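+ # (safetensors.torch.save_model, unlike save_file, resolves tied/shared tensors
+ # before writing, so models with weight tying serialize without errors.)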
166
+ model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
167
+ safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")
168
+
169
+ # Save optimizer state using PyTorch format
170
+ torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
171
+
172
+ # Save training metadata (the config is stored via dataclasses.asdict as a plain dict to avoid pickling JAX/Flax objects)
173
+ metadata = {
174
+ "global_step": global_step,
175
+ "config": dataclasses.asdict(config),
176
+ "timestamp": time.time(),
177
+ }
178
+ torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
179
+
180
+ # save norm stats
181
+ norm_stats = data_config.norm_stats
182
+ if norm_stats is not None and data_config.asset_id is not None:
183
+ _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)
184
+
185
+ # Atomically move temp directory to final location
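+ # (Writing into a tmp_* directory first and renaming at the end means a crash
+ # mid-save can never leave a half-written step directory; load_checkpoint
+ # ignores tmp_* directories.)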
186
+ if final_ckpt_dir.exists():
187
+ shutil.rmtree(final_ckpt_dir)
188
+ tmp_ckpt_dir.rename(final_ckpt_dir)
189
+
190
+ logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
191
+
192
+ # Log checkpoint to wandb
193
+ if config.wandb_enabled:
194
+ wandb.log({"checkpoint_step": global_step}, step=global_step)
195
+
196
+
197
+ def load_checkpoint(model, optimizer, checkpoint_dir, device):
198
+ """Load the latest checkpoint and return the global step."""
199
+ checkpoint_steps = [
200
+ int(d.name)
201
+ for d in checkpoint_dir.iterdir()
202
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
203
+ ]
204
+
205
+ if not checkpoint_steps:
206
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
207
+
208
+ latest_step = max(checkpoint_steps)
209
+ ckpt_dir = checkpoint_dir / f"{latest_step}"
210
+
211
+ # Clear memory before loading checkpoints
212
+ if torch.cuda.is_available():
213
+ torch.cuda.empty_cache()
214
+ gc.collect()
215
+ log_memory_usage(device, latest_step, "before_loading_checkpoint")
216
+
217
+ try:
218
+ # Load model state with error handling
219
+ logging.info("Loading model state...")
220
+ safetensors_path = ckpt_dir / "model.safetensors"
221
+
222
+ if safetensors_path.exists():
223
+ model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
224
+ safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
225
+ logging.info("Loaded model state from safetensors format")
226
+ else:
227
+ raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
228
+
229
+ torch.cuda.empty_cache()
230
+ gc.collect()
231
+ log_memory_usage(device, latest_step, "after_loading_model")
232
+
233
+ # Load optimizer state with error handling
234
+ logging.info("Loading optimizer state...")
235
+ optimizer_path = ckpt_dir / "optimizer.pt"
236
+
237
+ if optimizer_path.exists():
238
+ optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
239
+ logging.info("Loaded optimizer state from pt format")
240
+ else:
241
+ raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
242
+
243
+ optimizer.load_state_dict(optimizer_state_dict)
244
+ del optimizer_state_dict
245
+ torch.cuda.empty_cache()
246
+ gc.collect()
247
+ log_memory_usage(device, latest_step, "after_loading_optimizer")
248
+
249
+ # Load metadata
250
+ logging.info("Loading metadata...")
251
+ metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
252
+ global_step = metadata.get("global_step", latest_step)
253
+ del metadata
254
+ torch.cuda.empty_cache()
255
+ gc.collect()
256
+ log_memory_usage(device, latest_step, "after_loading_metadata")
257
+
258
+ logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
259
+ return global_step
260
+
261
+ except RuntimeError as e:
262
+ if "out of memory" in str(e):
263
+ # Clear memory and provide detailed error message
264
+ torch.cuda.empty_cache()
265
+ gc.collect()
266
+ logging.error(f"Out of memory error while loading checkpoint: {e!s}")
267
+ log_memory_usage(device, latest_step, "after_oom_error")
268
+ raise RuntimeError(
269
+ "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
270
+ ) from e
271
+ raise
272
+
273
+
274
+ def get_latest_checkpoint_step(checkpoint_dir):
275
+ """Get the latest checkpoint step number from a checkpoint directory."""
276
+ checkpoint_steps = [
277
+ int(d.name)
278
+ for d in checkpoint_dir.iterdir()
279
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
280
+ ]
281
+ return max(checkpoint_steps) if checkpoint_steps else None
282
+
283
+
284
+ def log_memory_usage(device, step, phase="unknown"):
285
+ """Log detailed memory usage information."""
286
+ if not torch.cuda.is_available():
287
+ return
288
+
289
+ memory_allocated = torch.cuda.memory_allocated(device) / 1e9
290
+ memory_reserved = torch.cuda.memory_reserved(device) / 1e9
291
+ # Reserved-but-unallocated memory in CUDA's caching allocator (not total free GPU memory)
+ memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9
293
+
294
+ # Get more detailed memory info
295
+ memory_stats = torch.cuda.memory_stats(device)
296
+ max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
297
+ max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9
298
+
299
+ # Get DDP info if available
300
+ ddp_info = ""
301
+ if dist.is_initialized():
302
+ ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
303
+
304
+ logging.info(
305
+ f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
306
+ )
307
+
308
+
309
+ def train_loop(config: _config.TrainConfig):
310
+ use_ddp, local_rank, device = setup_ddp()
311
+ is_main = (not use_ddp) or (dist.get_rank() == 0)
312
+ set_seed(config.seed, local_rank)
313
+
314
+ # Initialize checkpoint directory and wandb
315
+ resuming = False
316
+ if config.resume:
317
+ # Find checkpoint directory based on experiment name
318
+ exp_checkpoint_dir = config.checkpoint_dir
319
+ if exp_checkpoint_dir.exists():
320
+ # Use validation to find the latest working checkpoint
321
+ latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
322
+ if latest_step is not None:
323
+ resuming = True
324
+ logging.info(
325
+ f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
326
+ )
327
+ else:
328
+ raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
329
+ else:
330
+ raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
331
+ elif config.overwrite and config.checkpoint_dir.exists():
332
+ shutil.rmtree(config.checkpoint_dir)
333
+ logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
334
+
335
+ # Create checkpoint directory with experiment name
336
+ if not resuming:
337
+ # For new runs, create experiment-specific checkpoint directory
338
+ exp_checkpoint_dir = config.checkpoint_dir
339
+ exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
340
+ logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
341
+ else:
342
+ # For resume, checkpoint_dir is already set to the experiment directory
343
+ logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
344
+
345
+ # Initialize wandb (only on main process)
346
+ if is_main:
347
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
348
+
349
+ # Build data loader using the unified data loader
350
+ # Calculate effective batch size per GPU for DDP
351
+ # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
352
+ world_size = torch.distributed.get_world_size() if use_ddp else 1
353
+ effective_batch_size = config.batch_size // world_size
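+ # Illustrative numbers (not from any shipped config): batch_size=64 with world_size=4
+ # gives 16 samples per GPU per step.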
354
+ logging.info(
355
+ f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
356
+ )
357
+
358
+ # Pass the original batch size to data loader - it will handle DDP splitting internally
359
+ loader, data_config = build_datasets(config)
360
+
361
+ # Log sample images to wandb on first batch
362
+ if is_main and config.wandb_enabled and not resuming:
363
+ # Create a separate data loader for sample batch to avoid consuming the main loader
364
+ sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
365
+ sample_batch = next(iter(sample_data_loader))
366
+ # Convert observation and actions to torch tensors
367
+ observation, actions = sample_batch
368
+ sample_batch = observation.to_dict()
369
+ sample_batch["actions"] = actions
370
+
371
+ # Create sample images for wandb
372
+ images_to_log = []
373
+ # Get batch size from the first image tensor
374
+ batch_size = next(iter(sample_batch["image"].values())).shape[0]
375
+ for i in range(min(5, batch_size)):
376
+ # Concatenate all camera views horizontally for this batch item
377
+ # Convert from NCHW to NHWC format for wandb
378
+ img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
379
+ img_concatenated = img_concatenated.cpu().numpy()
380
+ images_to_log.append(wandb.Image(img_concatenated))
381
+
382
+ wandb.log({"camera_views": images_to_log}, step=0)
383
+
384
+ # Clear sample batch from memory aggressively
385
+ del sample_batch, observation, actions, images_to_log, img_concatenated
386
+ del sample_data_loader # Also delete the sample data loader
387
+ gc.collect()
388
+ if torch.cuda.is_available():
389
+ torch.cuda.empty_cache()
390
+ logging.info("Cleared sample batch and data loader from memory")
391
+
392
+ # Build model
393
+ if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
394
+ # Convert dataclass to Pi0Config if needed
395
+ model_cfg = openpi.models.pi0_config.Pi0Config(
396
+ dtype=config.pytorch_training_precision,
397
+ action_dim=config.model.action_dim,
398
+ action_horizon=config.model.action_horizon,
399
+ max_token_len=config.model.max_token_len,
400
+ paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
401
+ action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
402
+ pi05=getattr(config.model, "pi05", False),
403
+ )
404
+ else:
405
+ model_cfg = config.model
406
+ # Update dtype to match pytorch_training_precision
407
+ object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
408
+
409
+ model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
410
+
411
+ if hasattr(model, "gradient_checkpointing_enable"):
412
+ enable_gradient_checkpointing = True
413
+ model.gradient_checkpointing_enable()
414
+ logging.info("Enabled gradient checkpointing for memory optimization")
415
+ else:
416
+ enable_gradient_checkpointing = False
417
+ logging.info("Gradient checkpointing is not supported for this model")
418
+
419
+ # Log initial memory usage after model creation
420
+ if is_main and torch.cuda.is_available():
421
+ log_memory_usage(device, 0, "after_model_creation")
422
+
423
+ # Enable memory optimizations for large-scale training
424
+ if world_size >= 8:
425
+ torch.backends.cudnn.benchmark = True
426
+ torch.backends.cuda.matmul.allow_tf32 = True
427
+ torch.backends.cudnn.allow_tf32 = True
428
+ # Set memory allocation configuration. Note: the CUDA caching allocator reads this
+ # variable at its first allocation, so exporting it before launching the process is
+ # more reliable than setting it here.
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
430
+ logging.info("Enabled memory optimizations for 8+ GPU training")
431
+
432
+ if use_ddp:
433
+ model = torch.nn.parallel.DistributedDataParallel(
434
+ model,
435
+ device_ids=[device.index] if device.type == "cuda" else None,
436
+ find_unused_parameters=True, # Enabled: some parameters may not receive gradients every step (costs extra overhead)
437
+ gradient_as_bucket_view=True, # Enable for memory efficiency
438
+ static_graph=world_size >= 8, # Enable for 8+ GPUs
439
+ )
440
+
441
+ # Load weights from weight_loader if specified (for fine-tuning)
442
+ if config.pytorch_weight_path is not None:
443
+ logging.info(f"Loading weights from: {config.pytorch_weight_path}")
444
+
445
+ model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
446
+ safetensors.torch.load_model(
447
+ (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
448
+ )
449
+ logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
450
+
451
+ # Optimizer + learning rate schedule from config
452
+ warmup_steps = config.lr_schedule.warmup_steps
453
+ peak_lr = config.lr_schedule.peak_lr
454
+ decay_steps = config.lr_schedule.decay_steps
455
+ end_lr = config.lr_schedule.decay_lr
456
+
457
+ # Create optimizer with config parameters
458
+ optim = torch.optim.AdamW(
459
+ model.parameters(),
460
+ lr=peak_lr,
461
+ betas=(config.optimizer.b1, config.optimizer.b2),
462
+ eps=config.optimizer.eps,
463
+ weight_decay=config.optimizer.weight_decay,
464
+ )
465
+
466
+ # Load checkpoint if resuming
467
+ global_step = 0
468
+ if resuming:
469
+ global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
470
+ logging.info(f"Resumed training from step {global_step}")
471
+
472
+ def lr_schedule(step: int):
473
+ if step < warmup_steps:
474
+ # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
475
+ init_lr = peak_lr / (warmup_steps + 1)
476
+ return init_lr + (peak_lr - init_lr) * step / warmup_steps
477
+ # cosine decay
478
+ progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
479
+ cos = 0.5 * (1 + np.cos(np.pi * progress))
480
+ return end_lr + (peak_lr - end_lr) * cos
481
+
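+ # A worked example of the schedule above, with hypothetical values
+ # warmup_steps=1000, peak_lr=2.5e-5, decay_steps=30000, end_lr=2.5e-6:
+ #   step 0     -> ~2.5e-8   (peak_lr / (warmup_steps + 1))
+ #   step 1000  -> 2.5e-5    (end of linear warmup)
+ #   step 15500 -> ~1.375e-5 (cosine midpoint: end_lr + 0.5 * (peak_lr - end_lr))
+ #   step 30000 -> 2.5e-6    (end_lr, held constant afterwards)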
482
+ model.train()
483
+ start_time = time.time()
484
+ infos = [] # Collect stats over log interval
485
+ if is_main:
486
+ logging.info(
487
+ f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
488
+ )
489
+ logging.info(
490
+ f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
491
+ )
492
+ logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
493
+ logging.info(
494
+ f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
495
+ )
496
+ logging.info(
497
+ f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
498
+ )
499
+ logging.info("EMA is not supported for PyTorch training")
500
+ logging.info(f"Training precision: {model_cfg.dtype}")
501
+
502
+ # Training loop - iterate until we reach num_train_steps
503
+ pbar = (
504
+ tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
505
+ if is_main
506
+ else None
507
+ )
508
+
509
+ while global_step < config.num_train_steps:
510
+ # Set epoch for distributed training
511
+ if use_ddp and hasattr(loader, "set_epoch"):
512
+ loader.set_epoch(global_step // len(loader))
513
+
514
+ for observation, actions in loader:
515
+ # Check if we've reached the target number of steps
516
+ if global_step >= config.num_train_steps:
517
+ break
518
+
519
+ # The unified data loader returns (observation, actions) tuple
520
+ observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
521
+ actions = actions.to(torch.float32) # noqa: PLW2901
522
+ actions = actions.to(device) # noqa: PLW2901
523
+
524
+ # Update LR
525
+ for pg in optim.param_groups:
526
+ pg["lr"] = lr_schedule(global_step)
527
+
528
+ # Forward pass
529
+ losses = model(observation, actions)
530
+ # Ensure losses is a tensor and handle different return types
531
+ if isinstance(losses, list | tuple):
532
+ losses = torch.stack(losses)
533
+ elif not isinstance(losses, torch.Tensor):
534
+ losses = torch.tensor(losses, device=device, dtype=torch.float32)
535
+
536
+ loss = losses.mean()
537
+
538
+ # Backward pass
539
+ loss.backward()
540
+
541
+ # Log memory usage after backward pass
542
+ if global_step < 5 and is_main and torch.cuda.is_available():
543
+ log_memory_usage(device, global_step, "after_backward")
544
+
545
+ # Gradient clipping
546
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
547
+
548
+ # Optimizer step
549
+ optim.step()
550
+ optim.zero_grad(set_to_none=True)
551
+
552
+ # Clear gradients more aggressively
553
+ for param in model.parameters():
554
+ if param.grad is not None:
555
+ param.grad.detach_()
556
+ param.grad = None
557
+
558
+ # Collect stats
559
+ if is_main:
560
+ infos.append(
561
+ {
562
+ "loss": loss.item(),
563
+ "learning_rate": optim.param_groups[0]["lr"],
564
+ "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
565
+ }
566
+ )
567
+
568
+ if is_main and (global_step % config.log_interval == 0):
569
+ elapsed = time.time() - start_time
570
+
571
+ # Average stats over log interval
572
+ avg_loss = sum(info["loss"] for info in infos) / len(infos)
573
+ avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
574
+
575
+ avg_grad_norm = None
576
+ if any("grad_norm" in info for info in infos):
577
+ vals = [
578
+ info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
579
+ ]
580
+ if len(vals) > 0:
581
+ avg_grad_norm = sum(vals) / len(vals)
582
+ logging.info(
583
+ f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
584
+ if avg_grad_norm is not None
585
+ else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
586
+ )
587
+
588
+ # Log to wandb
589
+ if config.wandb_enabled and len(infos) > 0:
590
+ log_payload = {
591
+ "loss": avg_loss,
592
+ "learning_rate": avg_lr,
593
+ "step": global_step,
594
+ "time_per_step": elapsed / config.log_interval,
595
+ }
596
+ if avg_grad_norm is not None:
597
+ log_payload["grad_norm"] = avg_grad_norm
598
+ wandb.log(log_payload, step=global_step)
599
+
600
+ start_time = time.time()
601
+ infos = [] # Reset stats collection
602
+
603
+ global_step += 1
604
+ # Save checkpoint using the new mechanism
605
+ save_checkpoint(model, optim, global_step, config, is_main, data_config)
606
+
607
+ # Update progress bar
608
+ if pbar is not None:
609
+ pbar.update(1)
610
+ pbar.set_postfix(
611
+ {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
612
+ )
613
+
614
+ # Close progress bar
615
+ if pbar is not None:
616
+ pbar.close()
617
+
618
+ # Finish wandb run
619
+ if is_main and config.wandb_enabled:
620
+ wandb.finish()
621
+
622
+ cleanup_ddp()
623
+
624
+
625
+ def main():
626
+ init_logging()
627
+ config = _config.cli()
628
+ train_loop(config)
629
+
630
+
631
+ if __name__ == "__main__":
632
+ main()
capvector-pi05/scripts/train_regular_loss_pytorch.py ADDED
@@ -0,0 +1,754 @@
1
+ """
2
+ PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
3
+ This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
4
+ entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
5
+ pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
6
+
7
+ Usage
8
+ Single GPU:
9
+ python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
10
+ Example:
11
+ python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
12
+ python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
13
+ Multi-GPU (single node):
14
+ torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
15
+ Example:
16
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
17
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
18
+ Multi-Node Training:
19
+ torchrun \
20
+ --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
21
+ --master_addr=<master_ip> --master_port=<port> \
22
+ scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
23
+
24
+ """
25
+
26
+ import dataclasses
27
+ import gc
28
+ import logging
29
+ import os
30
+ import platform
31
+ from pathlib import Path
32
+ import shutil
33
+ import time
34
+
35
+ import jax
36
+ import numpy as np
37
+ import safetensors.torch
38
+ import torch
39
+ import torch.distributed as dist
40
+ import torch.nn.parallel
41
+ import tqdm
42
+ import wandb
43
+
44
+ import openpi.models.pi0_config
45
+ import openpi.models_pytorch.pi0_pytorch
46
+ import openpi.shared.normalize as _normalize
47
+ import openpi.training.config as _config
48
+ import openpi.training.data_loader as _data
49
+
50
+
51
+ def init_logging():
52
+ level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
53
+
54
+ class CustomFormatter(logging.Formatter):
55
+ def format(self, record):
56
+ record.levelname = level_mapping.get(record.levelname, record.levelname)
57
+ return super().format(record)
58
+
59
+ formatter = CustomFormatter(
60
+ fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
61
+ datefmt="%H:%M:%S",
62
+ )
63
+ logger = logging.getLogger()
64
+ logger.setLevel(logging.INFO)
65
+ if not logger.handlers:
66
+ ch = logging.StreamHandler()
67
+ ch.setFormatter(formatter)
68
+ logger.addHandler(ch)
69
+ else:
70
+ logger.handlers[0].setFormatter(formatter)
71
+
72
+
73
+ def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
74
+ """Initialize wandb logging."""
75
+ if not enabled:
76
+ wandb.init(mode="disabled")
77
+ return
78
+
79
+ ckpt_dir = config.checkpoint_dir
80
+ if not ckpt_dir.exists():
81
+ raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
82
+
83
+ if resuming:
84
+ run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
85
+ wandb.init(id=run_id, resume="must", project=config.project_name)
86
+ else:
87
+ wandb.init(
88
+ name=config.name,
89
+ config=dataclasses.asdict(config),
90
+ project=config.project_name,
91
+ id="-".join([config.name, config.exp_name]),
92
+ )
93
+ (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
94
+
95
+
96
+ def setup_ddp():
97
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
98
+ use_ddp = world_size > 1
99
+ if use_ddp and not torch.distributed.is_initialized():
100
+ backend = "nccl" if torch.cuda.is_available() else "gloo"
101
+ torch.distributed.init_process_group(backend=backend, init_method="env://")
102
+
103
+ # Set up debugging environment variables for DDP issues
104
+ if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
105
+ os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
106
+
107
+ local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
108
+ device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
109
+ if torch.cuda.is_available():
110
+ torch.cuda.set_device(device)
111
+ return use_ddp, local_rank, device
112
+
113
+
114
+ def cleanup_ddp():
115
+ if torch.distributed.is_initialized():
116
+ torch.distributed.barrier()
117
+ torch.distributed.destroy_process_group()
118
+
119
+
120
+ def set_seed(seed: int, local_rank: int):
121
+ torch.manual_seed(seed + local_rank)
122
+ np.random.seed(seed + local_rank)
123
+ if torch.cuda.is_available():
124
+ torch.cuda.manual_seed_all(seed + local_rank)
125
+
126
+
127
+ def build_datasets(config: _config.TrainConfig):
128
+ # Use the unified data loader with PyTorch framework
129
+ data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
130
+ return data_loader, data_loader.data_config()
131
+
132
+
133
+ def get_model_state_dict(model):
134
+ """Get state dict from model, handling DDP wrapper."""
135
+ return (
136
+ model.module.state_dict()
137
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
138
+ else model.state_dict()
139
+ )
140
+
141
+
142
+ def get_model_parameters(model):
143
+ """Get parameters from model, handling DDP wrapper."""
144
+ return (
145
+ model.module.parameters()
146
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
147
+ else model.parameters()
148
+ )
149
+
150
+
151
+ def load_regular_vector_dict(path: str | Path) -> dict[str, torch.Tensor]:
+ """Load the regularization vectors, which are used for delta-based regularization."""
+ tensor_path = Path(path)
+ suffix = tensor_path.suffix.lower()
+
+ if suffix in {".pt", ".pth"}:
+ # Torch payloads are expected to wrap the tensors in a {"state_dict": ...} dict.
+ tensors = torch.load(tensor_path, map_location="cpu", weights_only=False, mmap=True)
+ return tensors["state_dict"]
+ if suffix == ".safetensors":
+ # safetensors files are already a flat name -> tensor mapping, so return them as-is
+ # (indexing ["state_dict"] here would raise a KeyError).
+ return safetensors.torch.load_file(str(tensor_path), device="cpu")
+ raise ValueError(f"Unsupported tensor file format: {tensor_path}")
164
+
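+ # For reference, a compatible ".pt"/".pth" payload could be produced along these
+ # lines (a sketch with hypothetical names, not a script shipped in this repo):
+ #   vectors = {name: torch.randn_like(p.detach().cpu()) for name, p in model.named_parameters()}
+ #   torch.save({"state_dict": vectors}, "regularization_vectors.pt")
+ # A ".safetensors" file would instead store the flat name -> tensor mapping directly.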
165
+
166
+ def prepare_regularization_context(
167
+ model,
168
+ config: _config.TrainConfig,
169
+ ) -> dict | None:
170
+ """Load regularization tensors and build the runtime context for delta-based regularization."""
171
+
172
+ # Regularization is optional: skip it when no vector path is configured or the coefficient is zero
173
+ if not config.regularization_vector_path or config.regularization_coeff == 0.0:
174
+ return None
175
+
176
+ # Get the regularization vectors as reference directions
177
+ if config.resume:
178
+ raise ValueError(
179
+ "Delta-based regularization with --resume is not supported in this PyTorch trainer. "
180
+ "This run now keeps the anchor only in memory at startup."
181
+ )
182
+ vector_path = Path(config.regularization_vector_path).expanduser()
183
+ if not vector_path.exists():
184
+ raise FileNotFoundError(f"Regularization vector file does not exist: {vector_path}")
185
+ regularization_vectors = load_regular_vector_dict(vector_path)
186
+
187
+ # Get the model's trainable parameters to be regularized and the corresponding freezing anchors at startup
188
+ model_module = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
189
+
190
+ trainable_entries = []
191
+ missing_vectors = 0
192
+ shape_mismatches = 0
193
+ trainable_param_names = set()
194
+
195
+ for name, param in model_module.named_parameters():
196
+ if not param.requires_grad:
197
+ continue
198
+ trainable_param_names.add(name)
199
+ regularization_vector = regularization_vectors.get(name)
200
+ if regularization_vector is None:
201
+ missing_vectors += 1
202
+ continue
203
+ anchor_param = param.detach().clone().contiguous()
204
+ if regularization_vector.shape != param.shape or anchor_param.shape != param.shape:
205
+ shape_mismatches += 1
206
+ continue
207
+ trainable_entries.append(
208
+ {
209
+ "name": name,
210
+ "param": param,
211
+ "anchor": anchor_param,
212
+ "vector": regularization_vector.to(device=param.device, dtype=param.dtype).contiguous(),
213
+ }
214
+ )
215
+
216
+ logging.info(
217
+ "Regularization coverage: matched=%d missing_vectors=%d shape_mismatches=%d",
218
+ len(trainable_entries),
219
+ missing_vectors,
220
+ shape_mismatches,
221
+ )
222
+
223
+ return {
224
+ "entries": trainable_entries,
225
+ "weight": config.regularization_coeff,
226
+ "vector_path": str(vector_path),
227
+ }
228
+
229
+
230
+ def compute_regularization_loss(regularization_context: dict | None, device: torch.device) -> torch.Tensor:
231
+ """Compute the delta-based regularization loss for the current model parameters."""
232
+ reg_loss = torch.zeros((), device=device, dtype=torch.float32)
233
+
234
+ if not regularization_context:
235
+ return reg_loss
236
+
237
+ for entry in regularization_context["entries"]:
238
+ param = entry["param"]
239
+ anchor = entry["anchor"]
240
+ vector = entry["vector"]
241
+
242
+ delta = (param - anchor).reshape(-1).float()
243
+ direction = vector.reshape(-1).float()
244
+ reg_loss = reg_loss + torch.abs(torch.dot(delta, direction))
245
+
246
+ return reg_loss * regularization_context["weight"]
247
+
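+ # A minimal worked example of the penalty (hypothetical 2-element parameter):
+ #   param = [1.2, 0.9], anchor = [1.0, 1.0], vector = [0.5, -0.5]
+ #   delta = param - anchor = [0.2, -0.1]
+ #   |<delta, vector>| = |0.2*0.5 + (-0.1)*(-0.5)| = 0.15
+ # Movement along the reference direction is penalized; movement orthogonal to it
+ # contributes nothing.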
248
+
249
+ def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
250
+ """Save a checkpoint with model state, optimizer state, and metadata."""
251
+ if not is_main:
252
+ return
253
+
254
+ # Only save if it's time to save or if it's the final step
255
+ if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
256
+ # Create temporary directory for atomic checkpoint saving
257
+ final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
258
+ tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
259
+
260
+ # Remove any existing temp directory and create new one
261
+ if tmp_ckpt_dir.exists():
262
+ shutil.rmtree(tmp_ckpt_dir)
263
+ tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
264
+
265
+ # Save model state using safetensors (handle shared tensors)
266
+ model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
267
+ safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")
268
+
269
+ # Save optimizer state using PyTorch format
270
+ torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
271
+
272
+ # Save training metadata (the config is stored via dataclasses.asdict as a plain dict to avoid pickling JAX/Flax objects)
273
+ metadata = {
274
+ "global_step": global_step,
275
+ "config": dataclasses.asdict(config),
276
+ "timestamp": time.time(),
277
+ }
278
+ torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
279
+
280
+ # save norm stats
281
+ norm_stats = data_config.norm_stats
282
+ if norm_stats is not None and data_config.asset_id is not None:
283
+ _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)
284
+
285
+ # Atomically move temp directory to final location
286
+ if final_ckpt_dir.exists():
287
+ shutil.rmtree(final_ckpt_dir)
288
+ tmp_ckpt_dir.rename(final_ckpt_dir)
289
+
290
+ logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
291
+
292
+ # Log checkpoint to wandb
293
+ if config.wandb_enabled:
294
+ wandb.log({"checkpoint_step": global_step}, step=global_step)
295
+
296
+
297
+ def load_checkpoint(model, optimizer, checkpoint_dir, device):
298
+ """Load the latest checkpoint and return the global step."""
299
+ checkpoint_steps = [
300
+ int(d.name)
301
+ for d in checkpoint_dir.iterdir()
302
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
303
+ ]
304
+
305
+ if not checkpoint_steps:
306
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
307
+
308
+ latest_step = max(checkpoint_steps)
309
+ ckpt_dir = checkpoint_dir / f"{latest_step}"
310
+
311
+ # Clear memory before loading checkpoints
312
+ if torch.cuda.is_available():
313
+ torch.cuda.empty_cache()
314
+ gc.collect()
315
+ log_memory_usage(device, latest_step, "before_loading_checkpoint")
316
+
317
+ try:
318
+ # Load model state with error handling
319
+ logging.info("Loading model state...")
320
+ safetensors_path = ckpt_dir / "model.safetensors"
321
+
322
+ if safetensors_path.exists():
323
+ model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
324
+ safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
325
+ logging.info("Loaded model state from safetensors format")
326
+ else:
327
+ raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
328
+
329
+ torch.cuda.empty_cache()
330
+ gc.collect()
331
+ log_memory_usage(device, latest_step, "after_loading_model")
332
+
333
+ # Load optimizer state with error handling
334
+ logging.info("Loading optimizer state...")
335
+ optimizer_path = ckpt_dir / "optimizer.pt"
336
+
337
+ if optimizer_path.exists():
338
+ optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
339
+ logging.info("Loaded optimizer state from pt format")
340
+ else:
341
+ raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
342
+
343
+ optimizer.load_state_dict(optimizer_state_dict)
344
+ del optimizer_state_dict
345
+ torch.cuda.empty_cache()
346
+ gc.collect()
347
+ log_memory_usage(device, latest_step, "after_loading_optimizer")
348
+
349
+ # Load metadata
350
+ logging.info("Loading metadata...")
351
+ metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
352
+ global_step = metadata.get("global_step", latest_step)
353
+ del metadata
354
+ torch.cuda.empty_cache()
355
+ gc.collect()
356
+ log_memory_usage(device, latest_step, "after_loading_metadata")
357
+
358
+ logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
359
+ return global_step
360
+
361
+ except RuntimeError as e:
362
+ if "out of memory" in str(e):
363
+ # Clear memory and provide detailed error message
364
+ torch.cuda.empty_cache()
365
+ gc.collect()
366
+ logging.error(f"Out of memory error while loading checkpoint: {e!s}")
367
+ log_memory_usage(device, latest_step, "after_oom_error")
368
+ raise RuntimeError(
369
+ "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
370
+ ) from e
371
+ raise
372
+
373
+
374
+ def get_latest_checkpoint_step(checkpoint_dir):
375
+ """Get the latest checkpoint step number from a checkpoint directory."""
376
+ checkpoint_steps = [
377
+ int(d.name)
378
+ for d in checkpoint_dir.iterdir()
379
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
380
+ ]
381
+ return max(checkpoint_steps) if checkpoint_steps else None
382
+
383
+
384
+ def log_memory_usage(device, step, phase="unknown"):
385
+ """Log detailed memory usage information."""
386
+ if not torch.cuda.is_available():
387
+ return
388
+
389
+ memory_allocated = torch.cuda.memory_allocated(device) / 1e9
390
+ memory_reserved = torch.cuda.memory_reserved(device) / 1e9
391
+ # Reserved-but-unallocated memory in CUDA's caching allocator (not total free GPU memory)
+ memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9
393
+
394
+ # Get more detailed memory info
395
+ memory_stats = torch.cuda.memory_stats(device)
396
+ max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
397
+ max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9
398
+
399
+ # Get DDP info if available
400
+ ddp_info = ""
401
+ if dist.is_initialized():
402
+ ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
403
+
404
+ logging.info(
405
+ f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
406
+ )
407
+
408
+
409
+ def train_loop(config: _config.TrainConfig):
410
+ use_ddp, local_rank, device = setup_ddp()
411
+ is_main = (not use_ddp) or (dist.get_rank() == 0)
412
+ set_seed(config.seed, local_rank)
413
+
414
+ # Initialize checkpoint directory and wandb
415
+ resuming = False
416
+ if config.resume:
417
+ # Find checkpoint directory based on experiment name
418
+ exp_checkpoint_dir = config.checkpoint_dir
419
+ if exp_checkpoint_dir.exists():
420
+ # Use validation to find the latest working checkpoint
421
+ latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
422
+ if latest_step is not None:
423
+ resuming = True
424
+ logging.info(
425
+ f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
426
+ )
427
+ else:
428
+ raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
429
+ else:
430
+ raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
431
+ elif config.overwrite and config.checkpoint_dir.exists():
432
+ shutil.rmtree(config.checkpoint_dir)
433
+ logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
434
+
435
+ # Create checkpoint directory with experiment name
436
+ if not resuming:
437
+ # For new runs, create experiment-specific checkpoint directory
438
+ exp_checkpoint_dir = config.checkpoint_dir
439
+ exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
440
+ logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
441
+ else:
442
+ # For resume, checkpoint_dir is already set to the experiment directory
443
+ logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
444
+
445
+ # Initialize wandb (only on main process)
446
+ if is_main:
447
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
448
+
449
+ # Build data loader using the unified data loader
450
+ # Calculate effective batch size per GPU for DDP
451
+ # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
452
+ world_size = torch.distributed.get_world_size() if use_ddp else 1
453
+ effective_batch_size = config.batch_size // world_size
454
+ logging.info(
455
+ f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
456
+ )
457
+
458
+ # Pass the original batch size to data loader - it will handle DDP splitting internally
459
+ loader, data_config = build_datasets(config)
460
+
461
+ # Log sample images to wandb on first batch
462
+ if is_main and config.wandb_enabled and not resuming:
463
+ # Create a separate data loader for sample batch to avoid consuming the main loader
464
+ sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
465
+ sample_batch = next(iter(sample_data_loader))
466
+ # Convert observation and actions to torch tensors
467
+ observation, actions = sample_batch
468
+ sample_batch = observation.to_dict()
469
+ sample_batch["actions"] = actions
470
+
471
+ # Create sample images for wandb
472
+ images_to_log = []
473
+ # Get batch size from the first image tensor
474
+ batch_size = next(iter(sample_batch["image"].values())).shape[0]
475
+ for i in range(min(5, batch_size)):
476
+ # Concatenate all camera views horizontally for this batch item
477
+ # Convert from NCHW to NHWC format for wandb
478
+ img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
479
+ img_concatenated = img_concatenated.cpu().numpy()
480
+ images_to_log.append(wandb.Image(img_concatenated))
481
+
482
+ wandb.log({"camera_views": images_to_log}, step=0)
483
+
484
+ # Clear sample batch from memory aggressively
485
+ del sample_batch, observation, actions, images_to_log, img_concatenated
486
+ del sample_data_loader # Also delete the sample data loader
487
+ gc.collect()
488
+ if torch.cuda.is_available():
489
+ torch.cuda.empty_cache()
490
+ logging.info("Cleared sample batch and data loader from memory")
491
+
492
+ # Build model
493
+ if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
494
+ # Convert dataclass to Pi0Config if needed
495
+ model_cfg = openpi.models.pi0_config.Pi0Config(
496
+ dtype=config.pytorch_training_precision,
497
+ action_dim=config.model.action_dim,
498
+ action_horizon=config.model.action_horizon,
499
+ max_token_len=config.model.max_token_len,
500
+ paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
501
+ action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
502
+ pi05=getattr(config.model, "pi05", False),
503
+ )
504
+ else:
505
+ model_cfg = config.model
506
+ # Update dtype to match pytorch_training_precision
507
+ object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
508
+
509
+ model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
510
+
511
+ if hasattr(model, "gradient_checkpointing_enable"):
512
+ enable_gradient_checkpointing = True
513
+ model.gradient_checkpointing_enable()
514
+ logging.info("Enabled gradient checkpointing for memory optimization")
515
+ else:
516
+ enable_gradient_checkpointing = False
517
+ logging.info("Gradient checkpointing is not supported for this model")
518
+
519
+ # Log initial memory usage after model creation
520
+ if is_main and torch.cuda.is_available():
521
+ log_memory_usage(device, 0, "after_model_creation")
522
+
523
+ # Enable memory optimizations for large-scale training
524
+ if world_size >= 8:
525
+ torch.backends.cudnn.benchmark = True
526
+ torch.backends.cuda.matmul.allow_tf32 = True
527
+ torch.backends.cudnn.allow_tf32 = True
528
+ # Set memory allocation configuration. Note: the CUDA caching allocator reads this
+ # variable at its first allocation, so exporting it before launching the process is
+ # more reliable than setting it here.
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
530
+ logging.info("Enabled memory optimizations for 8+ GPU training")
531
+
532
+ if use_ddp:
533
+ model = torch.nn.parallel.DistributedDataParallel(
534
+ model,
535
+ device_ids=[device.index] if device.type == "cuda" else None,
536
+ find_unused_parameters=True, # Enabled: some parameters may not receive gradients every step (costs extra overhead)
537
+ gradient_as_bucket_view=True, # Enable for memory efficiency
538
+ static_graph=world_size >= 8, # Enable for 8+ GPUs
539
+ )
540
+
541
+ # Load weights from weight_loader if specified (for fine-tuning)
542
+ if config.pytorch_weight_path is not None:
543
+ logging.info(f"Loading weights from: {config.pytorch_weight_path}")
544
+
545
+ model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
546
+ safetensors.torch.load_model(
547
+ (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
548
+ )
549
+ logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
550
+
551
+ regularization_context = prepare_regularization_context(model, config)
552
+
553
+ # Optimizer + learning rate schedule from config
554
+ warmup_steps = config.lr_schedule.warmup_steps
555
+ peak_lr = config.lr_schedule.peak_lr
556
+ decay_steps = config.lr_schedule.decay_steps
557
+ end_lr = config.lr_schedule.decay_lr
558
+
559
+ # Create optimizer with config parameters
560
+ optim = torch.optim.AdamW(
561
+ model.parameters(),
562
+ lr=peak_lr,
563
+ betas=(config.optimizer.b1, config.optimizer.b2),
564
+ eps=config.optimizer.eps,
565
+ weight_decay=config.optimizer.weight_decay,
566
+ )
567
+
568
+ # Load checkpoint if resuming
569
+ global_step = 0
570
+ if resuming:
571
+ global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
572
+ logging.info(f"Resumed training from step {global_step}")
573
+
574
+ def lr_schedule(step: int):
575
+ if step < warmup_steps:
576
+ # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
577
+ init_lr = peak_lr / (warmup_steps + 1)
578
+ return init_lr + (peak_lr - init_lr) * step / warmup_steps
579
+ # cosine decay
580
+ progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
581
+ cos = 0.5 * (1 + np.cos(np.pi * progress))
582
+ return end_lr + (peak_lr - end_lr) * cos
583
+
584
+ model.train()
585
+ start_time = time.time()
586
+ infos = [] # Collect stats over log interval
587
+ if is_main:
588
+ logging.info(
589
+ f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
590
+ )
591
+ logging.info(
592
+ f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
593
+ )
594
+ logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
595
+ logging.info(
596
+ f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
597
+ )
598
+ logging.info(
599
+ f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
600
+ )
601
+ logging.info("EMA is not supported for PyTorch training")
602
+ logging.info(f"Training precision: {model_cfg.dtype}")
603
+ if regularization_context:
604
+ logging.info(
605
+ "Delta-based regularization: enabled | weight=%.2e | vector=%s",
606
+ config.regularization_coeff,
607
+ regularization_context["vector_path"],
608
+ )
609
+
610
+ # Training loop - iterate until we reach num_train_steps
611
+ pbar = (
612
+ tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
613
+ if is_main
614
+ else None
615
+ )
616
+
617
+ while global_step < config.num_train_steps:
618
+ # Set epoch for distributed training
619
+ if use_ddp and hasattr(loader, "set_epoch"):
620
+ loader.set_epoch(global_step // len(loader))
621
+
622
+ for observation, actions in loader:
623
+ # Check if we've reached the target number of steps
624
+ if global_step >= config.num_train_steps:
625
+ break
626
+
627
+ # The unified data loader returns (observation, actions) tuple
628
+ observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
629
+ actions = actions.to(torch.float32) # noqa: PLW2901
630
+ actions = actions.to(device) # noqa: PLW2901
631
+
632
+ # Update LR
633
+ for pg in optim.param_groups:
634
+ pg["lr"] = lr_schedule(global_step)
635
+
636
+ # Forward pass
637
+ losses = model(observation, actions)
638
+ # Ensure losses is a tensor and handle different return types
639
+ if isinstance(losses, list | tuple):
640
+ losses = torch.stack(losses)
641
+ elif not isinstance(losses, torch.Tensor):
642
+ losses = torch.tensor(losses, device=device, dtype=torch.float32)
643
+
644
+ action_loss = losses.mean()
645
+ regularization_loss = compute_regularization_loss(regularization_context, device)
646
+ total_loss = action_loss + regularization_loss
647
+
648
+ # Backward pass
649
+ total_loss.backward()
650
+
651
+ # Log memory usage after backward pass
652
+ if global_step < 5 and is_main and torch.cuda.is_available():
653
+ log_memory_usage(device, global_step, "after_backward")
654
+
655
+ # Gradient clipping
656
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
657
+
658
+ # Optimizer step
659
+ optim.step()
660
+ optim.zero_grad(set_to_none=True)
661
+
662
+ # Clear gradients more aggressively
663
+ for param in model.parameters():
664
+ if param.grad is not None:
665
+ param.grad.detach_()
666
+ param.grad = None
667
+
668
+ # Collect stats
669
+ if is_main:
670
+ infos.append(
671
+ {
672
+ "action_loss": action_loss.item(),
673
+ "regularization_loss": regularization_loss.item(),
674
+ "total_loss": total_loss.item(),
675
+ "learning_rate": optim.param_groups[0]["lr"],
676
+ "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
677
+ }
678
+ )
679
+
680
+ if is_main and (global_step % config.log_interval == 0):
681
+ elapsed = time.time() - start_time
682
+
683
+ # Average stats over log interval
684
+ avg_action_loss = sum(info["action_loss"] for info in infos) / len(infos)
685
+ avg_regularization_loss = sum(info["regularization_loss"] for info in infos) / len(infos)
686
+ avg_total_loss = sum(info["total_loss"] for info in infos) / len(infos)
687
+ avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
688
+
689
+ avg_grad_norm = None
690
+ if any("grad_norm" in info for info in infos):
691
+ vals = [
692
+ info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
693
+ ]
694
+ if len(vals) > 0:
695
+ avg_grad_norm = sum(vals) / len(vals)
696
+ logging.info(
697
+ f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
698
+ if avg_grad_norm is not None
699
+ else f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
700
+ )
701
+
702
+ # Log to wandb
703
+ if config.wandb_enabled and len(infos) > 0:
704
+ log_payload = {
705
+ "action_loss": avg_action_loss,
706
+ "regularization_loss": avg_regularization_loss,
707
+ "total_loss": avg_total_loss,
708
+ "learning_rate": avg_lr,
709
+ "step": global_step,
710
+ "time_per_step": elapsed / config.log_interval,
711
+ }
712
+ if avg_grad_norm is not None:
713
+ log_payload["grad_norm"] = avg_grad_norm
714
+ wandb.log(log_payload, step=global_step)
715
+
716
+ start_time = time.time()
717
+ infos = [] # Reset stats collection
718
+
719
+ global_step += 1
720
+ # Save checkpoint using the new mechanism
721
+ save_checkpoint(model, optim, global_step, config, is_main, data_config)
722
+
723
+ # Update progress bar
724
+ if pbar is not None:
725
+ pbar.update(1)
726
+ pbar.set_postfix(
727
+ {
728
+ "action_loss": f"{action_loss.item():.4f}",
729
+ "reg_loss": f"{regularization_loss.item():.4f}",
730
+ "total_loss": f"{total_loss.item():.4f}",
731
+ "lr": f"{optim.param_groups[0]['lr']:.2e}",
732
+ "step": global_step,
733
+ }
734
+ )
735
+
736
+ # Close progress bar
737
+ if pbar is not None:
738
+ pbar.close()
739
+
740
+ # Finish wandb run
741
+ if is_main and config.wandb_enabled:
742
+ wandb.finish()
743
+
744
+ cleanup_ddp()
745
+
746
+
747
+ def main():
748
+ init_logging()
749
+ config = _config.cli()
750
+ train_loop(config)
751
+
752
+
753
+ if __name__ == "__main__":
754
+ main()
capvector-pi05/scripts/train_test.py ADDED
@@ -0,0 +1,30 @@
1
+ import dataclasses
2
+ import os
3
+ import pathlib
4
+
5
+ import pytest
6
+
7
+ os.environ["JAX_PLATFORMS"] = "cpu"
8
+
9
+ from openpi.training import config as _config
10
+
11
+ from . import train
12
+
13
+
14
+ @pytest.mark.parametrize("config_name", ["debug"])
15
+ def test_train(tmp_path: pathlib.Path, config_name: str):
16
+ config = dataclasses.replace(
17
+ _config._CONFIGS_DICT[config_name], # noqa: SLF001
18
+ batch_size=2,
19
+ checkpoint_base_dir=str(tmp_path / "checkpoint"),
20
+ exp_name="test",
21
+ overwrite=False,
22
+ resume=False,
23
+ num_train_steps=2,
24
+ log_interval=1,
25
+ )
26
+ train.main(config)
27
+
28
+ # test resuming
29
+ config = dataclasses.replace(config, resume=True, num_train_steps=4)
30
+ train.main(config)
capvector-pi05/src/openpi/__init__.py ADDED
File without changes
capvector-pi05/src/openpi/conftest.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+
3
+ import pynvml
4
+ import pytest
5
+
6
+
7
+ def set_jax_cpu_backend_if_no_gpu() -> None:
8
+ try:
9
+ pynvml.nvmlInit()
10
+ pynvml.nvmlShutdown()
11
+ except pynvml.NVMLError:
12
+ # No GPU found.
13
+ os.environ["JAX_PLATFORMS"] = "cpu"
14
+
15
+
16
+ def pytest_configure(config: pytest.Config) -> None:
17
+ set_jax_cpu_backend_if_no_gpu()
capvector-pi05/src/openpi/models/__init__.py ADDED
File without changes
capvector-pi05/src/openpi/models/gemma.py ADDED
@@ -0,0 +1,459 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Gemma adaptation for Pi, taken from big_vision.
16
+
17
+ We follow this einsum axis naming convention:
18
+ B: batch
19
+ T: query length
20
+ S: k/v length
21
+ N: num query heads
22
+ K: num k/v heads
23
+ G: num query heads per k/v head
24
+ H: head dim
25
+ D: d_model ("features")
26
+ """
27
+
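As a quick illustration of this naming convention, the grouped-query attention logits computed later in this file contract the head dimension H between queries and keys (standalone toy shapes):

```python
import jax.numpy as jnp

B, T, S, K, G, H = 2, 4, 4, 1, 8, 16  # batch, q len, kv len, kv heads, q heads per kv head, head dim
q = jnp.zeros((B, T, K, G, H))
k = jnp.zeros((B, S, K, H))
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)  # same contraction as in Attention below
print(logits.shape)  # (2, 1, 8, 4, 4)
```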
28
+ from collections.abc import Sequence
29
+ import dataclasses
30
+ from typing import Literal, TypeAlias
31
+
32
+ import einops
33
+ import flax.linen as nn
34
+ import jax
35
+ import jax.numpy as jnp
36
+
37
+ import openpi.models.lora as lora
38
+ import openpi.shared.array_typing as at
39
+ import openpi.training.sharding as sharding
40
+
41
+ PALIGEMMA_VOCAB_SIZE = 257_152
42
+
43
+
44
+ @dataclasses.dataclass
45
+ class Config:
46
+ width: int
47
+ depth: int
48
+ mlp_dim: int
49
+ num_heads: int
50
+ num_kv_heads: int
51
+ head_dim: int
52
+ lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
53
+
54
+
55
+ Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
56
+
57
+
58
+ def get_config(variant: Variant) -> Config:
59
+ """Returns config for specified gemma variant."""
60
+ if variant == "dummy":
61
+ return Config(
62
+ width=64,
63
+ depth=4,
64
+ mlp_dim=128,
65
+ num_heads=8,
66
+ num_kv_heads=1,
67
+ head_dim=16,
68
+ )
69
+ if variant == "gemma_300m":
70
+ # 311M params
71
+ return Config(
72
+ width=1024,
73
+ depth=18,
74
+ mlp_dim=4096,
75
+ num_heads=8,
76
+ num_kv_heads=1,
77
+ head_dim=256,
78
+ )
79
+ if variant == "gemma_2b":
80
+ return Config(
81
+ width=2048,
82
+ depth=18,
83
+ mlp_dim=16_384,
84
+ num_heads=8,
85
+ num_kv_heads=1,
86
+ head_dim=256,
87
+ )
88
+ if variant == "gemma_2b_lora":
89
+ return Config(
90
+ width=2048,
91
+ depth=18,
92
+ mlp_dim=16_384,
93
+ num_heads=8,
94
+ num_kv_heads=1,
95
+ head_dim=256,
96
+ lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
97
+ )
98
+ if variant == "gemma_300m_lora":
99
+ # 311M params
100
+ return Config(
101
+ width=1024,
102
+ depth=18,
103
+ mlp_dim=4096,
104
+ num_heads=8,
105
+ num_kv_heads=1,
106
+ head_dim=256,
107
+ lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
108
+ )
109
+ raise ValueError(f"Unknown variant: {variant}")
110
+
111
+
112
+ @at.typecheck
113
+ class RMSNorm(nn.Module):
114
+ @nn.compact
115
+ def __call__(self, x, cond):
116
+ dtype = x.dtype # original dtype, could be half-precision
117
+ var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
118
+ normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
119
+ if cond is None:
120
+ # regular RMSNorm
121
+ scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))
122
+ normed_inputs = normed_inputs * (
123
+ 1 + scale
124
+ ) # scale by learned parameter in float32 (matches Flax implementation)
125
+ return normed_inputs.astype(dtype), None # return in original dtype
126
+
127
+ # adaptive RMSNorm
128
+ modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
129
+ scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
130
+ normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32
131
+ return normed_inputs.astype(dtype), gate
132
+
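In plain JAX, the non-adaptive branch above computes the following (a minimal sketch of the same math; the adaptive branch additionally predicts per-feature scale/shift/gate from the conditioning vector via a zero-initialized Dense layer):

```python
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
    # x: [..., d]; scale: [d], zero-initialized so (1 + scale) starts at identity
    var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
    return x * jnp.reciprocal(jnp.sqrt(var + eps)) * (1 + scale)
```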
133
+
134
+ @at.typecheck
135
+ class Embedder(nn.Module):
136
+ """Embedder module."""
137
+
138
+ vocab_size: int
139
+ embed_dim: int
140
+
141
+ def setup(self):
142
+ self.input_embedding_table = self.param(
143
+ "input_embedding",
144
+ nn.initializers.normal(),
145
+ (self.vocab_size, self.embed_dim),
146
+ )
147
+
148
+ def encode(self, x):
149
+ x = self.input_embedding_table[(x,)]
150
+ x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
151
+ return x
152
+
153
+ def decode(self, x):
154
+ return jnp.dot(x, self.input_embedding_table.T)
155
+
156
+
157
+ @at.typecheck
158
+ class Attention(nn.Module):
159
+ """Attention module."""
160
+
161
+ configs: Sequence[Config]
162
+
163
+ @nn.compact
164
+ def __call__(self, xs, positions, attn_mask, kv_cache):
165
+ # all experts must share the same head dim, num heads, and num kv heads for self-attention to work
166
+ assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
167
+ assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
168
+ assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
169
+
170
+ dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
171
+
172
+ qkvs = []
173
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
174
+ if x is None:
175
+ continue
176
+ if config.num_kv_heads == config.num_heads:
177
+ qkv_einsum = lora.Einsum(
178
+ shape=(3, config.num_heads, config.width, config.head_dim),
179
+ name=_name("qkv_einsum", i),
180
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
181
+ lora_config=config.lora_configs.get("attn"),
182
+ )
183
+ qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
184
+ else:
185
+ q_einsum = lora.Einsum(
186
+ shape=(config.num_heads, config.width, config.head_dim),
187
+ name=_name("q_einsum", i),
188
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
189
+ lora_config=config.lora_configs.get("attn"),
190
+ )
191
+ q = q_einsum("BTD,NDH->BTNH", x)
192
+ kv_einsum = lora.Einsum(
193
+ shape=(2, config.num_kv_heads, config.width, config.head_dim),
194
+ name=_name("kv_einsum", i),
195
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
196
+ lora_config=config.lora_configs.get("attn"),
197
+ )
198
+ k, v = kv_einsum("BSD,2KDH->2BSKH", x)
199
+ qkvs.append((q, k, v))
200
+
201
+ q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
202
+
203
+ q = _apply_rope(q, positions=positions)
204
+ q *= self.configs[0].head_dim ** -0.5
205
+
206
+ k = _apply_rope(k, positions=positions)
207
+
208
+ # should still be half-precision here (if input was half-precision)
209
+ assert q.dtype == k.dtype == v.dtype == dtype
210
+
211
+ if kv_cache is not None:
212
+ cache_k, cache_v = kv_cache
213
+ k = jnp.concatenate([cache_k, k], axis=1)
214
+ v = jnp.concatenate([cache_v, v], axis=1)
215
+
216
+ q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
217
+ logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
218
+
219
+ if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
220
+ raise ValueError(
221
+ f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
222
+ )
223
+
224
+ # big_neg = jnp.finfo(logits.dtype).min
225
+ big_neg = -2.3819763e38 # See gemma/modules.py
226
+ masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
227
+
228
+ probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
229
+
230
+ encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
231
+ encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
232
+
233
+ out = []
234
+ start = 0
235
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
236
+ if x is not None:
237
+ end = start + x.shape[1]
238
+ out_einsum = lora.Einsum(
239
+ shape=(config.num_heads, config.head_dim, config.width),
240
+ name=_name("attn_vec_einsum", i),
241
+ init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
242
+ lora_config=config.lora_configs.get("attn"),
243
+ )
244
+ out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
245
+ start = end
246
+ else:
247
+ out.append(None)
248
+
249
+ return out, (k, v)
250
+
251
+
252
+ @at.typecheck
253
+ class FeedForward(nn.Module):
254
+ """Feed forward module."""
255
+
256
+ features: int
257
+ hidden_dim: int
258
+
259
+ @nn.compact
260
+ def __call__(self, x):
261
+ dtype = x.dtype # original dtype, could be half-precision
262
+ w_gating = self.param(
263
+ "gating_einsum",
264
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
265
+ (2, self.features, self.hidden_dim),
266
+ ).astype(dtype)
267
+ ff_gate = jnp.dot(x, w_gating[0])
268
+ gate_value = nn.gelu(ff_gate)
269
+
270
+ ff1 = jnp.dot(x, w_gating[1])
271
+ activations = gate_value * ff1
272
+
273
+ w_linear = self.param(
274
+ "linear",
275
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
276
+ (self.hidden_dim, self.features),
277
+ ).astype(dtype)
278
+ outputs = jnp.dot(activations, w_linear)
279
+ assert outputs.dtype == dtype
280
+ return outputs
281
+
282
+
283
+ @at.typecheck
284
+ class Block(nn.Module):
285
+ """Transformer block."""
286
+
287
+ configs: tuple[Config, ...]
288
+
289
+ dropout: float = 0.0
290
+ dropout_bdims: tuple[int, ...] = ()
291
+
292
+ @nn.compact
293
+ def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002
294
+ xs = sharding.activation_sharding_constraint(xs)
295
+ drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
296
+
297
+ attn = Attention(configs=self.configs, name="attn")
298
+
299
+ pre_attn = []
300
+ gates = []
301
+ for i, x in enumerate(xs):
302
+ if x is not None:
303
+ x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
304
+ pre_attn.append(x)
305
+ gates.append(gate if x is not None else None)
306
+
307
+ pre_attn = sharding.activation_sharding_constraint(pre_attn)
308
+ post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
309
+ post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
310
+ post_attn = sharding.activation_sharding_constraint(post_attn)
311
+ xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
312
+ xs = sharding.activation_sharding_constraint(xs)
313
+
314
+ out = []
315
+ gates = []
316
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
317
+ if x is not None:
318
+ x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
319
+ x = lora.FeedForward( # noqa: PLW2901
320
+ features=config.width,
321
+ hidden_dim=config.mlp_dim,
322
+ name=_name("mlp", i),
323
+ lora_config=config.lora_configs.get("ffn"),
324
+ )(x)
325
+ out.append(x)
326
+ gates.append(gate if x is not None else None)
327
+
328
+ out = sharding.activation_sharding_constraint(out)
329
+ out = jax.tree.map(lambda x: drop(x, deterministic), out)
330
+ xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
331
+ xs = sharding.activation_sharding_constraint(xs)
332
+
333
+ return xs, kv_cache
334
+
335
+
336
+ KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
337
+
338
+
339
+ @at.typecheck
340
+ class Module(nn.Module):
341
+ """Transformer model, supporting a mixture of different weights for different tokens."""
342
+
343
+ configs: Sequence[Config] # list of configs, one for each expert
344
+ embed_dtype: str
345
+
346
+ dropout: float = 0.0
347
+ dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
348
+ adarms: bool = False
349
+
350
+ def setup(self):
351
+ # all experts must have the same depth
352
+ assert all(config.depth == self.configs[0].depth for config in self.configs)
353
+
354
+ self.embedder = Embedder(
355
+ vocab_size=PALIGEMMA_VOCAB_SIZE,
356
+ embed_dim=self.configs[0].width, # embedder for first expert only
357
+ name="embedder",
358
+ )
359
+ block_cls = nn.remat(
360
+ Block,
361
+ prevent_cse=False,
362
+ static_argnums=(5,), # 0=self, 6=deterministic
363
+ policy=jax.checkpoint_policies.nothing_saveable,
364
+ )
365
+ self.layers = nn.scan(
366
+ block_cls,
367
+ variable_axes={"params": 0},
368
+ split_rngs={"params": True, "dropout": True},
369
+ in_axes=(
370
+ 0,
371
+ nn.broadcast,
372
+ nn.broadcast,
373
+ nn.broadcast,
374
+ nn.broadcast,
375
+ ), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
376
+ length=self.configs[0].depth,
377
+ )(
378
+ configs=self.configs,
379
+ dropout=self.dropout,
380
+ dropout_bdims=self.dropout_bdims,
381
+ )
382
+ self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
383
+
384
+ @at.typecheck
385
+ def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
386
+ return self.embedder.encode(tokens).astype(self.embed_dtype)
387
+
388
+ @at.typecheck
389
+ def __call__(
390
+ self,
391
+ # list of token arrays, one for each expert, or None if that expert should not be run
392
+ embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
393
+ positions: at.Int[at.Array, "b t"],
394
+ mask: at.Bool[at.Array, "b t s"],
395
+ adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
396
+ *,
397
+ kv_cache: KVCache | None = None,
398
+ deterministic: bool = True,
399
+ ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
400
+ embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
401
+ mask = jnp.asarray(mask)[:, None, :, :]
402
+ if adarms_cond is None:
403
+ adarms_cond = [None] * len(self.configs)
404
+
405
+ embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
406
+
407
+ assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
408
+
409
+ return [
410
+ f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
411
+ ], kv_cache
412
+
413
+ def init(self, use_adarms: Sequence[bool]):
414
+ """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
415
+ self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
416
+ self(
417
+ [jnp.zeros((1, 1, c.width)) for c in self.configs],
418
+ jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
419
+ jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
420
+ adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
421
+ )
422
+
423
+
424
+ def _apply_rope(x, *, positions, max_wavelength=10_000):
425
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
426
+ freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
427
+ timescale = max_wavelength**freq_exponents
428
+ radians = positions[..., None] / timescale[None, None, :]
429
+ radians = radians[..., None, :]
430
+ assert radians.dtype == jnp.float32
431
+ # radians.shape = [...,L,1,d=D/2]
432
+ sin, cos = jnp.sin(radians), jnp.cos(radians)
433
+ x1, x2 = jnp.split(x, 2, axis=-1)
434
+ res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
435
+ assert res.dtype == jnp.float32
436
+ # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
437
+ # dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
438
+ # original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
439
+ # here.
440
+ return res.astype(x.dtype)
441
+
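A useful sanity check (illustrative, not part of the module): because RoPE is a rotation of each (x1, x2) pair, it changes phases but preserves vector norms:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (1, 8, 2, 64))    # [B, L, H, D]
positions = jnp.arange(8, dtype=jnp.float32)[None, :]       # [B, L]
rotated = _apply_rope(x, positions=positions)
assert jnp.allclose(jnp.linalg.norm(x, axis=-1), jnp.linalg.norm(rotated, axis=-1), atol=1e-4)
```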
442
+
443
+ def _name(name, i):
444
+ # we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
445
+ # can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
446
+ # "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
447
+ # and the action expert.
448
+ if i == 0:
449
+ return name
450
+ return f"{name}_{i}"
451
+
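Concretely (illustrative):

```python
assert _name("attn", 0) == "attn"    # expert 0: names match the PaliGemma checkpoint
assert _name("attn", 1) == "attn_1"  # expert 1 (action expert): initialized from scratch
```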
452
+
453
+ def _gated_residual(x, y, gate):
454
+ assert (x is None) == (y is None)
455
+ if x is None:
456
+ return None
457
+ if gate is None:
458
+ return x + y
459
+ return x + y * gate
capvector-pi05/src/openpi/models/gemma_fast.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
17
+ Used for FAST autoregressive policies.
18
+ """
19
+
20
+ import dataclasses
21
+ from typing import Literal, TypeAlias
22
+
23
+ import einops
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import ml_collections
28
+
29
+ import openpi.models.lora as lora
30
+ import openpi.shared.array_typing as at
31
+
32
+ Variant = Literal["gemma_2b", "gemma_2b_lora"]
33
+
34
+
35
+ def get_config(variant):
36
+ """Returns config for specified gemma variant."""
37
+ if variant == "gemma_2b":
38
+ return ml_collections.ConfigDict(
39
+ {
40
+ "variant": variant,
41
+ "width": 2048,
42
+ "depth": 18,
43
+ "mlp_dim": 16_384,
44
+ "num_heads": 8,
45
+ "num_kv_heads": 1,
46
+ "head_dim": 256,
47
+ "norm_eps": 1e-6,
48
+ "vocab_size": 257_152,
49
+ "scan": True,
50
+ "remat_policy": "nothing_saveable",
51
+ }
52
+ )
53
+ if variant == "gemma_2b_lora":
54
+ return ml_collections.ConfigDict(
55
+ {
56
+ "variant": variant,
57
+ "width": 2048,
58
+ "depth": 18,
59
+ "mlp_dim": 16_384,
60
+ "num_heads": 8,
61
+ "num_kv_heads": 1,
62
+ "head_dim": 256,
63
+ "norm_eps": 1e-6,
64
+ "vocab_size": 257_152,
65
+ "scan": True,
66
+ "remat_policy": "nothing_saveable",
67
+ "lora_configs": {
68
+ "attn": lora.LoRAConfig(rank=16, alpha=16.0),
69
+ "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
70
+ },
71
+ }
72
+ )
73
+ raise ValueError(f"Unknown variant: {variant}")
74
+
75
+
76
+ @at.typecheck
77
+ class Einsum(nn.Module):
78
+ shape: tuple[int, ...]
79
+
80
+ @nn.compact
81
+ def __call__(self, eqn, x):
82
+ dtype = x.dtype # original dtype, could be half-precision
83
+ w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
84
+ return jnp.einsum(eqn, x, w)
85
+
86
+
87
+ @at.typecheck
88
+ class RMSNorm(nn.Module):
89
+ @nn.compact
90
+ def __call__(self, x):
91
+ dtype = x.dtype # original dtype, could be half-precision
92
+ scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))
93
+ var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
94
+ normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
95
+ normed_inputs = normed_inputs * (
96
+ 1 + scale
97
+ ) # scale by learned parameter in float32 (matches Flax implementation)
98
+ return normed_inputs.astype(dtype) # return in original dtype
99
+
100
+
101
+ @at.typecheck
102
+ class Embedder(nn.Module):
103
+ """Embedder module."""
104
+
105
+ vocab_size: int
106
+ embed_dim: int
107
+
108
+ def setup(self):
109
+ self.input_embedding_table = self.param(
110
+ "input_embedding",
111
+ nn.initializers.zeros_init(),
112
+ (self.vocab_size, self.embed_dim),
113
+ )
114
+
115
+ def encode(self, x):
116
+ x = self.input_embedding_table[(x,)]
117
+ x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
118
+ return x
119
+
120
+ def decode(self, x):
121
+ return jnp.dot(x, self.input_embedding_table.T)
122
+
123
+
124
+ @at.typecheck
125
+ class Attention(nn.Module):
126
+ """Attention module."""
127
+
128
+ num_heads: int
129
+ num_kv_heads: int
130
+ features: int
131
+ head_dim: int
132
+
133
+ cache_dtype: str | None = None
134
+
135
+ lora_config: lora.LoRAConfig | None = None
136
+
137
+ def setup(self):
138
+ if self.num_kv_heads == self.num_heads:
139
+ self.qkv_einsum = lora.Einsum(
140
+ shape=(3, self.num_heads, self.features, self.head_dim),
141
+ name="qkv_einsum",
142
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
143
+ lora_config=self.lora_config,
144
+ )
145
+ else:
146
+ self.q_einsum = lora.Einsum(
147
+ shape=(self.num_heads, self.features, self.head_dim),
148
+ name="q_einsum",
149
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
150
+ lora_config=self.lora_config,
151
+ )
152
+ self.kv_einsum = lora.Einsum(
153
+ shape=(2, self.num_kv_heads, self.features, self.head_dim),
154
+ name="kv_einsum",
155
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
156
+ lora_config=self.lora_config,
157
+ )
158
+ self.attn_vec_einsum = lora.Einsum(
159
+ shape=(self.num_heads, self.head_dim, self.features),
160
+ name="attn_vec_einsum",
161
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
162
+ lora_config=self.lora_config,
163
+ )
164
+
165
+ def _init_cache(self, k, v, cache_size):
166
+ """Initialize KV cache"""
167
+ prefill_len = k.shape[1]
168
+ pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
169
+ cache_dtype = self.cache_dtype or k.dtype
170
+ k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
171
+ v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
172
+ idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
173
+ return idx, k_cache, v_cache
174
+
175
+ def _update_cache(self, k, v, idx, k_cache, v_cache):
176
+ """Update KV cache with new values"""
177
+ assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
178
+ indices = (0, idx[0], 0, 0)
179
+ cache_dtype = self.cache_dtype or k.dtype
180
+ k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
181
+ v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
182
+ idx_new = idx + 1
183
+ return idx_new, k_new, v_new
184
+
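A standalone sketch of the cache write performed above (toy shapes; `dynamic_update_slice` writes the new length-1 key/value block at the current index):

```python
import jax
import jax.numpy as jnp

k_cache = jnp.zeros((1, 8, 1, 4))        # [B, cache_len, K, H]
k_step = jnp.ones((1, 1, 1, 4))          # keys for one decode step
idx = jnp.array([3], dtype=jnp.int32)    # current write position
k_cache = jax.lax.dynamic_update_slice(k_cache, k_step, (0, idx[0], 0, 0))
assert float(k_cache[0, 3].sum()) == 4.0
```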
185
+ @nn.compact
186
+ def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002
187
+ dtype = x.dtype # original dtype, could be half-precision
188
+ if self.num_kv_heads == self.num_heads:
189
+ q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
190
+ else:
191
+ q = self.q_einsum("BTD,NDH->BTNH", x)
192
+ k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)
193
+
194
+ q = _apply_rope(q, positions=positions) # promotes to float32
195
+ q *= self.head_dim**-0.5
196
+
197
+ k = _apply_rope(k, positions=positions) # promotes to float32
198
+
199
+ if kv_cache is None:
200
+ idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
201
+ else:
202
+ idx, k_cache, v_cache = kv_cache
203
+ idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)
204
+
205
+ k, v = k_cache, v_cache
206
+ kv_cache = (idx, k_cache, v_cache)
207
+
208
+ q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
209
+ logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
210
+
211
+ if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
212
+ raise ValueError(
213
+ f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
214
+ )
215
+
216
+ # big_neg = jnp.finfo(logits.dtype).min
217
+ big_neg = -2.3819763e38 # See gemma/modules.py
218
+ masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
219
+
220
+ probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
221
+
222
+ encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
223
+ encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
224
+ return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
225
+
226
+
227
+ @at.typecheck
228
+ class Block(nn.Module):
229
+ """Transformer block."""
230
+
231
+ num_heads: int
232
+ num_kv_heads: int
233
+ embed_dim: int
234
+ head_dim: int
235
+ hidden_dim: int
236
+
237
+ dropout: float = 0.0
238
+ dropout_bdims: tuple[int, ...] = ()
239
+ cache_dtype: str | None = None
240
+ lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
241
+
242
+ def setup(self):
243
+ self.pre_attention_norm = RMSNorm()
244
+ self.attn = Attention(
245
+ num_heads=self.num_heads,
246
+ num_kv_heads=self.num_kv_heads,
247
+ features=self.embed_dim,
248
+ head_dim=self.head_dim,
249
+ cache_dtype=self.cache_dtype,
250
+ lora_config=self.lora_configs.get("attn"),
251
+ )
252
+ self.pre_ffw_norm = RMSNorm()
253
+ self.mlp = lora.FeedForward(
254
+ features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
255
+ )
256
+ if self.dropout:
257
+ self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
258
+ else:
259
+ self.drop = lambda x, _: x
260
+
261
+ def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
262
+ x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
263
+ inputs_normalized = self.pre_attention_norm(x)
264
+ attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
265
+ attn_output = self.drop(attn_output, deterministic)
266
+ attn_output += x
267
+ residual = attn_output
268
+ attn_output = self.pre_ffw_norm(attn_output)
269
+ outputs = self.mlp(attn_output)
270
+ outputs = self.drop(outputs, deterministic)
271
+ outputs = residual + outputs
272
+ return outputs, kv_cache
273
+
274
+
275
+ KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
276
+
277
+
278
+ @at.typecheck
279
+ class Module(nn.Module):
280
+ """gemma model."""
281
+
282
+ variant: str
283
+
284
+ width: int
285
+ depth: int
286
+ mlp_dim: int
287
+ num_heads: int
288
+ num_kv_heads: int
289
+ head_dim: int
290
+ norm_eps: float
291
+ vocab_size: int
292
+ embed_dtype: str
293
+
294
+ dropout: float = 0.0
295
+ dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
296
+ cache_dtype: str | None = None
297
+
298
+ scan: bool = False
299
+ remat_policy: str = "none"
300
+ lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
301
+
302
+ @nn.compact
303
+ def __call__(
304
+ self,
305
+ tokens=None,
306
+ embedded_prefix=None,
307
+ embed_only=False, # noqa: FBT002
308
+ pre_logits=None,
309
+ positions=None,
310
+ mask=None,
311
+ decode=False, # noqa: FBT002
312
+ kv_cache=None,
313
+ deterministic=True, # noqa: FBT002
314
+ return_prelogits=False, # noqa: FBT002
315
+ ):
316
+ """Embed only, or complete forward pass.
317
+
318
+ Args:
319
+ tokens: Embedded and then appended to `embedded_prefix`. Can be None.
320
+ embedded_prefix: Optional prefix that is already embedded.
321
+ embed_only: Whether to compute embeddings only.
322
+ pre_logits: If present, computes logits from `pre_logits` and returns them.
323
+ positions: Optional `[B, T]` array specifying the absolute positions of
324
+ the tokens.
325
+ mask: Optional attention mask `[B, T, S]`.
326
+ decode: Whether to use kv-cache. Caller must pass masks and positions.
327
+ deterministic: Forwarded to all dropout layers.
328
+ return_prelogits: Whether to return the pre-logits.
329
+
330
+ Returns:
331
+ If `embed_only=False`, then `(logits, out)` will be returned.
332
+ If `embed_only=True`, then the embeddings will be returned.
333
+ If `return_prelogits=True`, then the pre-logits will be returned.
334
+ """
335
+ out = {}
336
+
337
+ embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
338
+
339
+ if pre_logits is not None:
340
+ x = out["pre_logits"] = pre_logits
341
+ logits = out["logits"] = embedder.decode(x)
342
+ return logits, out
343
+
344
+ x = []
345
+ if embedded_prefix is not None:
346
+ x.append(embedded_prefix)
347
+ if tokens is not None:
348
+ x.append(embedder.encode(tokens))
349
+
350
+ x = jnp.concatenate(x, axis=-2)
351
+ x = x.astype(self.embed_dtype)
352
+ batch_size, seq_len, width = x.shape
353
+
354
+ if embed_only:
355
+ return x
356
+
357
+ if decode:
358
+ assert positions is not None and mask is not None, ( # noqa: PT018
359
+ "Must explicitly pass positions and mask for decoding."
360
+ )
361
+
362
+ if positions is None:
363
+ positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
364
+ assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
365
+
366
+ if mask is None:
367
+ mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
368
+ if mask.ndim == 3:
369
+ mask = mask[:, None, :, :]
370
+ cache_size = max(seq_len, mask.shape[-1])
371
+ assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
372
+
373
+ if self.remat_policy == "none":
374
+ block_cls = Block
375
+ else:
376
+ block_cls = nn.remat(
377
+ Block,
378
+ prevent_cse=not self.scan,
379
+ static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
380
+ policy=getattr(jax.checkpoint_policies, self.remat_policy),
381
+ )
382
+
383
+ block_kw = {
384
+ "num_heads": self.num_heads,
385
+ "head_dim": self.head_dim,
386
+ "num_kv_heads": self.num_kv_heads,
387
+ "embed_dim": width,
388
+ "hidden_dim": self.mlp_dim,
389
+ "dropout": self.dropout,
390
+ "dropout_bdims": self.dropout_bdims,
391
+ "cache_dtype": self.cache_dtype,
392
+ "lora_configs": self.lora_configs,
393
+ }
394
+ layers = self.scope.push("layers")
395
+ blocks = [
396
+ nn.scan(
397
+ block_cls,
398
+ variable_axes={"params": 0},
399
+ split_rngs={"params": True, "dropout": True},
400
+ in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask
401
+ length=self.depth,
402
+ )(parent=layers, **block_kw)
403
+ ]
404
+ for block in blocks:
405
+ x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
406
+
407
+ assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
408
+ out["encoded"] = x
409
+
410
+ x = RMSNorm(name="final_norm")(x)
411
+ out["pre_logits"] = x
412
+ if return_prelogits:
413
+ return x, kv_cache, out
414
+
415
+ x = embedder.decode(x)
416
+ out["logits"] = x
417
+
418
+ return x, kv_cache, out
419
+
420
+ def init(self):
421
+ """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
422
+ self(jnp.zeros((1, 1), dtype=jnp.int32))
423
+
424
+
425
+ def _apply_rope(x, *, positions, max_wavelength=10_000):
426
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
427
+ freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
428
+ timescale = max_wavelength**freq_exponents
429
+ radians = positions[..., None] / timescale[None, None, :]
430
+ radians = radians[..., None, :]
431
+ assert radians.dtype == jnp.float32
432
+ # radians.shape = [...,L,1,d=D/2]
433
+ sin, cos = jnp.sin(radians), jnp.cos(radians)
434
+ x1, x2 = jnp.split(x, 2, axis=-1)
435
+ res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
436
+ assert res.dtype == jnp.float32
437
+ return res
capvector-pi05/src/openpi/models/lora.py ADDED
@@ -0,0 +1,148 @@
1
+ import math
2
+ import re
3
+
4
+ import flax.linen as nn
5
+ import flax.struct as struct
6
+ import jax.numpy as jnp
7
+
8
+ import openpi.shared.array_typing as at
9
+
10
+
11
+ @struct.dataclass
12
+ class LoRAConfig:
13
+ """Configuration for LoRA."""
14
+
15
+ # LoRA rank.
16
+ rank: int
17
+ # LoRA scaling factor.
18
+ alpha: float = 1.0
19
+ # Initialization function for LoRA parameters.
20
+ init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
21
+ # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
22
+ rslora: bool = False
23
+ # Axes in the weight to apply LoRA to. Should typically be the last two axes.
24
+ axes: tuple[int, int] = (-2, -1)
25
+ # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
26
+ label: str = "L"
27
+
28
+ @property
29
+ def scaling_value(self) -> float:
30
+ return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
31
+
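For example, with the `rank=16, alpha=16.0` configuration used by the 2B LoRA variants, the standard scaling is 16/16 = 1.0, while rank-stabilized scaling is 16/sqrt(16) = 4.0, which keeps update magnitudes more stable as the rank grows:

```python
assert LoRAConfig(rank=16, alpha=16.0).scaling_value == 1.0
assert LoRAConfig(rank=16, alpha=16.0, rslora=True).scaling_value == 4.0
```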
32
+
33
+ class Einsum(nn.Module):
34
+ """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
35
+
36
+ # Shape of the weight.
37
+ shape: tuple[int, ...]
38
+ # Initialization function for the weight.
39
+ init_fn: nn.initializers.Initializer = nn.initializers.zeros
40
+ # If not None, apply LoRA to the weight.
41
+ lora_config: LoRAConfig | None = None
42
+
43
+ def setup(self):
44
+ self.w = self.param("w", self.init_fn, self.shape)
45
+
46
+ if config := self.lora_config:
47
+ # Setup LoRA parameters.
48
+ shape_a, shape_b = list(self.shape), list(self.shape)
49
+ shape_a[config.axes[1]] = config.rank
50
+ shape_b[config.axes[0]] = config.rank
51
+ self.w_a = self.param("lora_a", config.init_fn, shape_a)
52
+ self.w_b = self.param("lora_b", config.init_fn, shape_b)
53
+
54
+ @nn.compact
55
+ def __call__(self, eqn: str, x):
56
+ dtype = x.dtype # original dtype, could be half-precision
57
+ result = jnp.einsum(eqn, x, self.w.astype(dtype))
58
+
59
+ if config := self.lora_config:
60
+ eqn_a, eqn_b = self._make_lora_eqns(eqn)
61
+ lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
62
+ lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
63
+ result = result + lora * config.scaling_value
64
+
65
+ return result
66
+
67
+ def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
68
+ if "L" in eqn:
69
+ raise ValueError(f"L already in eqn: {eqn}")
70
+ if not (m := re.match("(.*),(.*)->(.*)", eqn)):
71
+ raise ValueError(f"Unsupported einsum eqn: {eqn}")
72
+ lhs, rhs, out = m.groups()
73
+
74
+ assert self.lora_config is not None
75
+ a_label, b_label = (rhs[x] for x in self.lora_config.axes)
76
+ label = self.lora_config.label
77
+
78
+ a_rhs = rhs.replace(b_label, label)
79
+ a_out = out.replace(b_label, label)
80
+ eqn_a = f"{lhs},{a_rhs}->{a_out}"
81
+
82
+ b_rhs = rhs.replace(a_label, label)
83
+ eqn_b = f"{a_out},{b_rhs}->{out}"
84
+
85
+ return eqn_a, eqn_b
86
+
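A worked example of the rewriting above, using the query projection equation with the default `axes=(-2, -1)` and label `L` (comments only, to show the string transformations):

```python
# eqn   = "BTD,NDH->BTNH"   ->  a_label = "D", b_label = "H"
# eqn_a = "BTD,NDL->BTNL"   # x @ lora_a: the rank axis L replaces H
# eqn_b = "BTNL,NLH->BTNH"  # intermediate @ lora_b: L replaces D
```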
87
+
88
+ class FeedForward(nn.Module):
89
+ """Feed forward module."""
90
+
91
+ features: int
92
+ hidden_dim: int
93
+ # If not None, apply LoRA to the weight.
94
+ lora_config: LoRAConfig | None = None
95
+
96
+ def setup(self):
97
+ self.w_gating = self.param(
98
+ "gating_einsum",
99
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
100
+ (2, self.features, self.hidden_dim),
101
+ )
102
+ self.w_linear = self.param(
103
+ "linear",
104
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
105
+ (self.hidden_dim, self.features),
106
+ )
107
+ self.w_gating_lora = None
108
+ self.w_linear_lora = None
109
+ if self.lora_config:
110
+ # Setup LoRA parameters.
111
+ # TODO: follow up with a simplified init_fn api.
112
+ self.w_gating_lora = (
113
+ self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
114
+ self.param(
115
+ "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
116
+ ),
117
+ )
118
+ self.w_linear_lora = (
119
+ self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
120
+ self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
121
+ )
122
+
123
+ @nn.compact
124
+ def __call__(self, x):
125
+ dtype = x.dtype # original dtype, could be half-precision
126
+ ff_gate = self._dot(
127
+ x,
128
+ self.w_gating[0],
129
+ None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
130
+ )
131
+ gate_value = nn.gelu(ff_gate)
132
+
133
+ ff1 = self._dot(
134
+ x,
135
+ self.w_gating[1],
136
+ None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
137
+ )
138
+ activations = gate_value * ff1
139
+
140
+ outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
141
+ assert outputs.dtype == dtype
142
+ return outputs
143
+
144
+ def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
145
+ base = jnp.dot(x, w.astype(x.dtype))
146
+ if lora_weights is None:
147
+ return base
148
+ return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
capvector-pi05/src/openpi/models/lora_test.py ADDED
@@ -0,0 +1,94 @@
1
+ import flax.linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ import openpi.models.lora as lora
6
+
7
+
8
+ def test_lora_einsum_params_shape():
9
+ shape = (3, 8, 32, 4) # (3KDH)
10
+ einsum = lora.Einsum(shape)
11
+ lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
12
+ lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
13
+
14
+ key = jax.random.key(0)
15
+ x = jax.random.normal(key, (8, 64, 32)) # (BSD)
16
+ eqn = "BSD,3KDH->3BSKH"
17
+
18
+ # Ensure that lora parameters are not initialized when LoRA is not used.
19
+ params = einsum.init(key, eqn, x)
20
+ assert "lora_a" not in params["params"]
21
+ assert "lora_b" not in params["params"]
22
+
23
+ # Check that default axes work.
24
+ params_lora0 = lora0.init(key, eqn, x)
25
+ assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
26
+ assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)
27
+
28
+ # Check that user provided axes work.
29
+ params_lora1 = lora1.init(key, eqn, x)
30
+ assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
31
+ assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
32
+
33
+
34
+ def test_lora_einsum_same_output():
35
+ shape = (3, 8, 32, 4) # (3KDH)
36
+ einsum = lora.Einsum(shape)
37
+ einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
38
+
39
+ key = jax.random.key(0)
40
+ x = jax.random.normal(key, (8, 64, 32)) # (BSD)
41
+ eqn = "BSD,3KDH->3BSKH"
42
+
43
+ params = einsum.init(key, eqn, x)
44
+ output = einsum.apply(params, eqn, x)
45
+
46
+ params_lora = einsum_lora.init(key, eqn, x)
47
+ output_lora = einsum_lora.apply(params_lora, eqn, x)
48
+
49
+ # Results are the same since the LoRA parameters are initialized to zeros.
50
+ assert jnp.allclose(output, output_lora)
51
+
52
+
53
+ def test_lora_ffn_params_shape():
54
+ ffn = lora.FeedForward(features=8, hidden_dim=32)
55
+ ffn_lora = lora.FeedForward(
56
+ features=8,
57
+ hidden_dim=32,
58
+ lora_config=lora.LoRAConfig(rank=2),
59
+ )
60
+
61
+ key = jax.random.key(0)
62
+ x = jax.random.normal(key, (2, 8))
63
+
64
+ params = ffn.init(key, x)
65
+ assert params["params"]["gating_einsum"].shape == (2, 8, 32)
66
+ assert params["params"]["linear"].shape == (32, 8)
67
+
68
+ params_lora = ffn_lora.init(key, x)
69
+ assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
70
+ assert params_lora["params"]["linear"].shape == (32, 8)
71
+ assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
72
+ assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
73
+ assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
74
+ assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
75
+
76
+
77
+ def test_lora_ffn_same_output():
78
+ ffn = lora.FeedForward(features=8, hidden_dim=32)
79
+ ffn_lora = lora.FeedForward(
80
+ features=8,
81
+ hidden_dim=32,
82
+ lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
83
+ )
84
+
85
+ key = jax.random.key(0)
86
+ x = jax.random.normal(key, (2, 8))
87
+
88
+ params = ffn.init(key, x)
89
+ output = ffn.apply(params, x)
90
+
91
+ params_lora = ffn_lora.init(key, x)
92
+ output_lora = ffn_lora.apply(params_lora, x)
93
+
94
+ assert jnp.allclose(output, output_lora)
capvector-pi05/src/openpi/models/model.py ADDED
@@ -0,0 +1,335 @@
1
+ import abc
2
+ from collections.abc import Sequence
3
+ import dataclasses
4
+ import enum
5
+ import logging
6
+ import pathlib
7
+ from typing import Generic, TypeVar
8
+
9
+ import augmax
10
+ from flax import nnx
11
+ from flax import struct
12
+ from flax import traverse_util
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import numpy as np
16
+ import orbax.checkpoint as ocp
17
+ import safetensors
18
+ import torch
19
+
20
+ from openpi.models_pytorch import pi0_pytorch
21
+ from openpi.shared import image_tools
22
+ import openpi.shared.array_typing as at
23
+
24
+ logger = logging.getLogger("openpi")
25
+
26
+ # Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
27
+ ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
28
+
29
+
30
+ class ModelType(enum.Enum):
31
+ """Supported model types."""
32
+
33
+ PI0 = "pi0"
34
+ PI0_FAST = "pi0_fast"
35
+ PI05 = "pi05"
36
+
37
+
38
+ # The model always expects these images
39
+ IMAGE_KEYS = (
40
+ "base_0_rgb",
41
+ "left_wrist_0_rgb",
42
+ "right_wrist_0_rgb",
43
+ )
44
+
45
+
46
+ # This may need to change if we release a small model.
47
+ IMAGE_RESOLUTION = (224, 224)
48
+
49
+
50
+ # Data format
51
+ #
52
+ # Data transforms produce the model input as a nested dictionary which is later converted
53
+ # into `Observation` and `Actions` objects. See below.
54
+ #
55
+ # In the dictionary form, this data should look like:
56
+ # {
57
+ # # Observation data.
58
+ # "image": {
59
+ # "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
60
+ # ... # Additional camera views
61
+ # },
62
+ # "image_mask": {
63
+ # "base_0_rgb": bool[*b], # True if image is valid
64
+ # ... # Masks for additional views
65
+ # },
66
+ # "state": float32[*b, s], # Low-dimensional robot state
67
+ # "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
68
+ # "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
69
+ # "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
70
+ # "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
71
+ #
72
+ # # Actions data.
73
+ # "actions": float32[*b ah ad]
74
+ # }
75
+ # where:
76
+ # *b = batch dimensions
77
+ # h,w = image height/width
78
+ # s = state dimension
79
+ # l = sequence length
80
+ #
81
+ @at.typecheck
82
+ @struct.dataclass
83
+ class Observation(Generic[ArrayT]):
84
+ """Holds observations, i.e., inputs to the model.
85
+
86
+ See `Observation.from_dict` to see the expected dictionary form. This is the format
87
+ that should be produced by the data transforms.
88
+ """
89
+
90
+ # Images, in [-1, 1] float32.
91
+ images: dict[str, at.Float[ArrayT, "*b h w c"]]
92
+ # the padding area for non-rectangular input images is False
93
+ image_padding_mask: dict[str, at.Bool[ArrayT, "*b w c"]]
94
+ # Image masks, with same keys as images.
95
+ image_masks: dict[str, at.Bool[ArrayT, "*b"]]
96
+ # Low-dimensional robot state.
97
+ state: at.Float[ArrayT, "*b s"]
98
+
99
+ # Tokenized prompt.
100
+ tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
101
+ # Tokenized prompt mask.
102
+ tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
103
+
104
+ # pi0-fast model specific fields.
105
+
106
+ # Token auto-regressive mask (for FAST autoregressive model).
107
+ token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
108
+ # Token loss mask (for FAST autoregressive model).
109
+ token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None
110
+
111
+ @classmethod
112
+ def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
113
+ """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
114
+ # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
115
+ if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
116
+ raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
117
+ # If images are uint8, convert them to [-1, 1] float32.
118
+ for key in data["image"]:
119
+ if data["image"][key].dtype == np.uint8:
120
+ data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
121
+ elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
122
+ data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
123
+ return cls(
124
+ images=data["image"],
125
+ image_padding_mask=data.get("image_padding_mask", {}),
126
+ image_masks=data["image_mask"],
127
+ state=data["state"],
128
+ tokenized_prompt=data.get("tokenized_prompt"),
129
+ tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
130
+ token_ar_mask=data.get("token_ar_mask"),
131
+ token_loss_mask=data.get("token_loss_mask"),
132
+ )
133
+
134
+ def to_dict(self) -> at.PyTree[ArrayT]:
135
+ """Convert the Observation to a nested dict."""
136
+ result = dataclasses.asdict(self)
137
+ result["image"] = result.pop("images")
138
+ result["image_mask"] = result.pop("image_masks")
139
+ return result
140
+
141
+
142
+ # Defines the format of the actions. This field is included as "actions" inside the dictionary
143
+ # produced by the data transforms.
144
+ Actions = at.Float[ArrayT, "*b ah ad"]
145
+
146
+
147
+ def preprocess_observation(
148
+ rng: at.KeyArrayLike | None,
149
+ observation: Observation,
150
+ *,
151
+ train: bool = False,
152
+ image_keys: Sequence[str] = IMAGE_KEYS,
153
+ image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
154
+ ) -> Observation:
155
+ """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
156
+ filling in a default image mask (if necessary).
157
+ """
158
+
159
+ if not set(image_keys).issubset(observation.images):
160
+ raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
161
+
162
+ batch_shape = observation.state.shape[:-1]
163
+
164
+ out_images = {}
165
+ for key in image_keys:
166
+ image = observation.images[key]
167
+ if image.shape[1:3] != image_resolution:
168
+ logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
169
+ image = image_tools.resize_with_pad(image, *image_resolution)
170
+
171
+ if train:
172
+ # Convert from [-1, 1] to [0, 1] for augmax.
173
+ image = image / 2.0 + 0.5
174
+
175
+ transforms = []
176
+ if "wrist" not in key:
177
+ height, width = image.shape[1:3]
178
+ transforms += [
179
+ augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
180
+ augmax.Resize(width, height),
181
+ augmax.Rotate((-5, 5)),
182
+ ]
183
+ transforms += [
184
+ augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
185
+ ]
186
+ sub_rngs = jax.random.split(rng, image.shape[0])
187
+ image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
188
+
189
+ # Back to [-1, 1].
190
+ image = image * 2.0 - 1.0
191
+
192
+ out_images[key] = image
193
+
194
+ # obtain mask
195
+ out_masks = {}
196
+ for key in out_images:
197
+ if key not in observation.image_masks:
198
+ # do not mask by default
199
+ out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
200
+ else:
201
+ out_masks[key] = jnp.asarray(observation.image_masks[key])
202
+
203
+ return Observation(
204
+ images=out_images,
+ image_padding_mask=observation.image_padding_mask,  # required field; passed through unchanged
205
+ image_masks=out_masks,
206
+ state=observation.state,
207
+ tokenized_prompt=observation.tokenized_prompt,
208
+ tokenized_prompt_mask=observation.tokenized_prompt_mask,
209
+ token_ar_mask=observation.token_ar_mask,
210
+ token_loss_mask=observation.token_loss_mask,
211
+ )
212
+
213
+
214
+ @dataclasses.dataclass(frozen=True)
215
+ class BaseModelConfig(abc.ABC):
216
+ """Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
217
+ method to create the corresponding model.
218
+ """
219
+
220
+ # Action space dimension.
221
+ action_dim: int
222
+ # Action sequence length.
223
+ action_horizon: int
224
+ # Tokenized prompt maximum length.
225
+ max_token_len: int
226
+
227
+ @property
228
+ @abc.abstractmethod
229
+ def model_type(self) -> ModelType:
230
+ """The model type."""
231
+
232
+ @abc.abstractmethod
233
+ def create(self, rng: at.KeyArrayLike) -> "BaseModel":
234
+ """Create a new model, initializing parameters."""
235
+
236
+ def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
237
+ """Create a model with the given parameters."""
238
+ model = nnx.eval_shape(self.create, jax.random.key(0))
239
+ graphdef, state = nnx.split(model)
240
+ if remove_extra_params:
241
+ params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
242
+ at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
243
+ state.replace_by_pure_dict(params)
244
+ return nnx.merge(graphdef, state)
245
+
246
+ def load_pytorch(self, train_config, weight_path: str):
247
+ logger.info(f"train_config: {train_config}")
248
+ model = pi0_pytorch.PI0Pytorch(config=train_config.model)
249
+ safetensors.torch.load_model(model, weight_path)
250
+ return model
251
+
252
+ @abc.abstractmethod
253
+ def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
254
+ """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""
255
+
256
+ def fake_obs(self, batch_size: int = 1) -> Observation:
257
+ observation_spec, _ = self.inputs_spec(batch_size=batch_size)
258
+ return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)
259
+
260
+ def fake_act(self, batch_size: int = 1) -> Actions:
261
+ _, action_spec = self.inputs_spec(batch_size=batch_size)
262
+ return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
+
+
+@dataclasses.dataclass
+class BaseModel(nnx.Module, abc.ABC):
+    """Base class for all model implementations. Specific models should inherit from this class. They should call
+    super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
+    """
+
+    action_dim: int
+    action_horizon: int
+    max_token_len: int
+
+    @abc.abstractmethod
+    def compute_loss(
+        self,
+        rng: at.KeyArrayLike,
+        observation: Observation,
+        actions: Actions,
+        *,
+        train: bool = False,
+    ) -> at.Float[at.Array, "*b ah"]: ...
+
+    @abc.abstractmethod
+    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...
+
+
+def restore_params(
+    params_path: pathlib.Path | str,
+    *,
+    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
+    dtype: jnp.dtype | None = None,
+    sharding: jax.sharding.Sharding | None = None,
+) -> at.Params:
+    """Restores an unstructured params PyTree from a checkpoint.
+
+    This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
+    well as pre-trained checkpoints released for openpi.
+
+    Args:
+        params_path: The local path to the checkpoint directory.
+        restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as numpy arrays.
+        dtype: The dtype to restore all params as. If not provided, the original dtype from the checkpoint is used.
+        sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.
+
+    Returns:
+        The restored params.
+    """
+    params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path
+
+    if restore_type is jax.Array and sharding is None:
+        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
+        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+
+    with ocp.PyTreeCheckpointer() as ckptr:
+        metadata = ckptr.metadata(params_path)
+        item = {"params": metadata["params"]}
+
+        params = ckptr.restore(
+            params_path,
+            ocp.args.PyTreeRestore(
+                item=item,
+                restore_args=jax.tree.map(
+                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
+                ),
+            ),
+        )["params"]
+
+    # If the params were saved with `save_state` during openpi training, every key path will end with "value", which
+    # is added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
+    flat_params = traverse_util.flatten_dict(params)
+    if all(kp[-1] == "value" for kp in flat_params):
+        flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
+    return traverse_util.unflatten_dict(flat_params)
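Putting the two pieces together, loading a released checkpoint typically looks like the following sketch (the path is hypothetical; `config` stands for any `BaseModelConfig` subclass instance):

import jax.numpy as jnp

# Hypothetical checkpoint directory; substitute a real one.
params = restore_params("./checkpoints/pi0_base/params", dtype=jnp.bfloat16)
model = config.load(params)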
capvector-pi05/src/openpi/models/model_test.py ADDED
@@ -0,0 +1,94 @@
+from flax import nnx
+import jax
+import pytest
+
+from openpi.models import model as _model
+from openpi.models import pi0_config
+from openpi.models import pi0_fast
+from openpi.shared import download
+from openpi.shared import nnx_utils
+
+
+def test_pi0_model():
+    key = jax.random.key(0)
+    config = pi0_config.Pi0Config()
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size, config.action_horizon)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
+    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
+
+
+def test_pi0_lora_model():
+    key = jax.random.key(0)
+    config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size, config.action_horizon)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
+    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
+
+
+def test_pi0_fast_model():
+    key = jax.random.key(0)
+    config = pi0_fast.Pi0FASTConfig()
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size,)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+    assert actions.shape == (batch_size, 256)
+
+
+def test_pi0_fast_lora_model():
+    key = jax.random.key(0)
+    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size,)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+    assert actions.shape == (batch_size, 256)
+
+    lora_filter = nnx_utils.PathRegex(".*lora.*")
+    model_state = nnx.state(model)
+
+    lora_state_elems = list(model_state.filter(lora_filter))
+    assert len(lora_state_elems) > 0
+
+
+@pytest.mark.manual
+def test_model_restore():
+    key = jax.random.key(0)
+    config = pi0_config.Pi0Config()
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    model = config.load(
+        _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
+    )
+
+    loss = model.compute_loss(key, obs, act)
+    assert loss.shape == (batch_size, config.action_horizon)
+
+    actions = model.sample_actions(key, obs, num_steps=10)
+    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
capvector-pi05/src/openpi/models/pi0.py ADDED
@@ -0,0 +1,279 @@
+import logging
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+from openpi.models import pi0_config
+import openpi.models.gemma as _gemma
+import openpi.models.siglip as _siglip
+from openpi.shared import array_typing as at
+
+logger = logging.getLogger("openpi")
+
+
+def make_attn_mask(input_mask, mask_ar):
+    """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens whose cumulative mask_ar is smaller
+    than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+      [[1 1 1 1 1 1]]: pure causal attention.
+
+      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+          themselves and the last 3 tokens have a causal attention. The first
+          entry could also be a 1 without changing behaviour.
+
+      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+          block can attend all previous blocks and all tokens on the same block.
+
+    Args:
+      input_mask: bool[B, N] true if it is part of the input, false if padding.
+      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+        it and false where it shares the same attention mask as the previous token.
+    """
+    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+    cumsum = jnp.cumsum(mask_ar, axis=1)
+    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+    return jnp.logical_and(attn_mask, valid_mask)
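A quick numeric check of the prefix-lm case from the docstring, using the function as defined above (values illustrative only):

import jax.numpy as jnp

input_mask = jnp.ones((1, 6), dtype=bool)
mask_ar = jnp.array([[0, 0, 0, 1, 1, 1]], dtype=bool)
mask = make_attn_mask(input_mask, mask_ar)
# Row i is the set of positions token i may attend to: tokens 0-2 (the prefix)
# attend to each other bidirectionally; tokens 3-5 attend to the prefix plus
# causally to earlier suffix tokens.
print(mask[0].astype(int))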
+
+
+@at.typecheck
+def posemb_sincos(
+    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
+) -> at.Float[at.Array, "b {embedding_dim}"]:
+    """Computes sine-cosine positional embedding vectors for scalar positions."""
+    if embedding_dim % 2 != 0:
+        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
+
+    fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
+    period = min_period * (max_period / min_period) ** fraction
+    sinusoid_input = jnp.einsum(
+        "i,j->ij",
+        pos,
+        1.0 / period * 2 * jnp.pi,
+        precision=jax.lax.Precision.HIGHEST,
+    )
+    return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
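For intuition: the periods are geometrically spaced between `min_period` and `max_period`, and each position gets a sine and a cosine per period. A small sanity check with `embedding_dim=4` (illustrative values):

import jax.numpy as jnp

emb = posemb_sincos(jnp.array([0.0, 0.5, 1.0]), 4, min_period=4e-3, max_period=4.0)
# Columns: sin at period 4e-3, sin at period 4.0, cos at period 4e-3, cos at period 4.0.
assert emb.shape == (3, 4)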
+
+
+class Pi0(_model.BaseModel):
+    def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
+        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
+        self.pi05 = config.pi05
+        paligemma_config = _gemma.get_config(config.paligemma_variant)
+        action_expert_config = _gemma.get_config(config.action_expert_variant)
+        # TODO: rewrite gemma in NNX. For now, use bridge.
+        llm = nnx_bridge.ToNNX(
+            _gemma.Module(
+                configs=[paligemma_config, action_expert_config],
+                embed_dtype=config.dtype,
+                adarms=config.pi05,
+            )
+        )
+        llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
+        img = nnx_bridge.ToNNX(
+            _siglip.Module(
+                num_classes=paligemma_config.width,
+                variant="So400m/14",
+                pool_type="none",
+                scan=True,
+                dtype_mm=config.dtype,
+            )
+        )
+        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
+        self.PaliGemma = nnx.Dict(llm=llm, img=img)
+        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
+        if config.pi05:
+            self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
+            self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
+        else:
+            self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
+            self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
+            self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
+        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
+
+        # This attribute gets automatically set by model.train() and model.eval().
+        self.deterministic = True
+
+    @at.typecheck
+    def embed_prefix(
+        self, obs: _model.Observation
+    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
+        input_mask = []
+        ar_mask = []
+        tokens = []
+        # embed images
+        for name in obs.images:
+            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
+
+            tokens.append(image_tokens)
+            input_mask.append(
+                einops.repeat(
+                    obs.image_masks[name],
+                    "b -> b s",
+                    s=image_tokens.shape[1],
+                )
+            )
+            # image tokens attend to each other
+            ar_mask += [False] * image_tokens.shape[1]
+
+        # add language (aka tokenized inputs)
+        if obs.tokenized_prompt is not None:
+            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
+            tokens.append(tokenized_inputs)
+            input_mask.append(obs.tokenized_prompt_mask)
+            # full attention between image and language inputs
+            ar_mask += [False] * tokenized_inputs.shape[1]
+        tokens = jnp.concatenate(tokens, axis=1)
+        input_mask = jnp.concatenate(input_mask, axis=1)
+        ar_mask = jnp.array(ar_mask)
+        return tokens, input_mask, ar_mask
+
+    @at.typecheck
+    def embed_suffix(
+        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
+    ) -> tuple[
+        at.Float[at.Array, "b s emb"],
+        at.Bool[at.Array, "b s"],
+        at.Bool[at.Array, " s"],
+        at.Float[at.Array, "b emb"] | None,
+    ]:
+        input_mask = []
+        ar_mask = []
+        tokens = []
+        if not self.pi05:
+            # add a single state token
+            state_token = self.state_proj(obs.state)[:, None, :]
+            tokens.append(state_token)
+            input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
+            # image/language inputs do not attend to state or actions
+            ar_mask += [True]
+
+        action_tokens = self.action_in_proj(noisy_actions)
+        # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
+        if self.pi05:
+            # time MLP (for adaRMS)
+            time_emb = self.time_mlp_in(time_emb)
+            time_emb = nnx.swish(time_emb)
+            time_emb = self.time_mlp_out(time_emb)
+            time_emb = nnx.swish(time_emb)
+            action_expert_tokens = action_tokens
+            adarms_cond = time_emb
+        else:
+            # mix timestep + action information using an MLP (no adaRMS)
+            time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
+            action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
+            action_time_tokens = self.action_time_mlp_in(action_time_tokens)
+            action_time_tokens = nnx.swish(action_time_tokens)
+            action_time_tokens = self.action_time_mlp_out(action_time_tokens)
+            action_expert_tokens = action_time_tokens
+            adarms_cond = None
+        tokens.append(action_expert_tokens)
+        input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
+        # image/language/state inputs do not attend to action tokens
+        ar_mask += [True] + ([False] * (self.action_horizon - 1))
+        tokens = jnp.concatenate(tokens, axis=1)
+        input_mask = jnp.concatenate(input_mask, axis=1)
+        ar_mask = jnp.array(ar_mask)
+        return tokens, input_mask, ar_mask, adarms_cond
+
+    @override
+    def compute_loss(
+        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
+    ) -> at.Float[at.Array, "*b ah"]:
+        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
+        observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
+
+        batch_shape = actions.shape[:-2]
+        noise = jax.random.normal(noise_rng, actions.shape)
+        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
+        time_expanded = time[..., None, None]
+        x_t = time_expanded * noise + (1 - time_expanded) * actions
+        u_t = noise - actions
+
+        # one big forward pass of prefix + suffix at once
+        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
+        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
+        input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
+        ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
+        attn_mask = make_attn_mask(input_mask, ar_mask)
+        positions = jnp.cumsum(input_mask, axis=1) - 1
+        (prefix_out, suffix_out), _ = self.PaliGemma.llm(
+            [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
+        )
+        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
+
+        return jnp.mean(jnp.square(v_t - u_t), axis=-1)
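The loss above is a flow-matching objective: `x_t` linearly interpolates between actions (t=0) and noise (t=1), so the regression target `u_t = noise - actions` is exactly `d x_t / d t`. A tiny model-free check of that identity (shapes illustrative):

import jax
import jax.numpy as jnp

noise = jax.random.normal(jax.random.key(0), (1, 50, 32))
actions = jax.random.normal(jax.random.key(1), (1, 50, 32))

def x(t):
    return t * noise + (1 - t) * actions

# The finite-difference derivative of a linear interpolation is exact.
u_fd = (x(0.6) - x(0.4)) / 0.2
assert jnp.allclose(u_fd, noise - actions, atol=1e-5)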
+
+    @override
+    def sample_actions(
+        self,
+        rng: at.KeyArrayLike,
+        observation: _model.Observation,
+        *,
+        num_steps: int | at.Int[at.Array, ""] = 10,
+        noise: at.Float[at.Array, "b ah ad"] | None = None,
+    ) -> _model.Actions:
+        observation = _model.preprocess_observation(None, observation, train=False)
+        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
+        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
+        dt = -1.0 / num_steps
+        batch_size = observation.state.shape[0]
+        if noise is None:
+            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
+
+        # first fill KV cache with a forward pass of the prefix
+        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
+        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
+        positions = jnp.cumsum(prefix_mask, axis=1) - 1
+        _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
+
+        def step(carry):
+            x_t, time = carry
+            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
+                observation, x_t, jnp.broadcast_to(time, batch_size)
+            )
+            # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to
+            # each other
+            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
+            # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to
+            # the prefix tokens
+            prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
+            # `full_attn_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens
+            # (which generate the queries) can attend to the full prefix + suffix sequence (which generates the keys
+            # and values)
+            full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
+            assert full_attn_mask.shape == (
+                batch_size,
+                suffix_tokens.shape[1],
+                prefix_tokens.shape[1] + suffix_tokens.shape[1],
+            )
+            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
+            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
+
+            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
+                [None, suffix_tokens],
+                mask=full_attn_mask,
+                positions=positions,
+                kv_cache=kv_cache,
+                adarms_cond=[None, adarms_cond],
+            )
+            assert prefix_out is None
+            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
+
+            return x_t + dt * v_t, time + dt
+
+        def cond(carry):
+            x_t, time = carry
+            # robust to floating-point error
+            return time >= -dt / 2
+
+        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
+        return x_0
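The `cond`/`step` pair above is a fixed-step Euler integrator from t=1 down to t=0; the `time >= -dt / 2` guard makes the loop run exactly `num_steps` iterations despite floating-point drift in `time`. A model-free sketch of the same loop structure (assumes nothing from openpi):

import jax
import jax.numpy as jnp

num_steps = 10
dt = -1.0 / num_steps

def step(carry):
    count, time = carry
    return count + 1, time + dt

def cond(carry):
    _, time = carry
    # Stop once time reaches ~0; the half-step margin absorbs rounding error.
    return time >= -dt / 2

count, time = jax.lax.while_loop(cond, step, (0, 1.0))
assert count == num_steps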
capvector-pi05/src/openpi/models/pi0_config.py ADDED
@@ -0,0 +1,108 @@
+import dataclasses
+from typing import TYPE_CHECKING
+
+import flax.nnx as nnx
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+import openpi.models.gemma as _gemma
+from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+
+if TYPE_CHECKING:
+    from openpi.models.pi0 import Pi0
+
+
+@dataclasses.dataclass(frozen=True)
+class Pi0Config(_model.BaseModelConfig):
+    dtype: str = "bfloat16"
+    paligemma_variant: _gemma.Variant = "gemma_2b"
+    action_expert_variant: _gemma.Variant = "gemma_300m"
+
+    # Set the model specific defaults.
+    action_dim: int = 32
+    action_horizon: int = 50
+    max_token_len: int = None  # type: ignore
+    # Pi05 has two differences from Pi0:
+    # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
+    # - the action expert uses adaRMSNorm to inject the flow matching timestep
+    pi05: bool = False
+    # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
+    discrete_state_input: bool = None  # type: ignore
+
+    def __post_init__(self):
+        if self.max_token_len is None:
+            object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
+        if self.discrete_state_input is None:
+            object.__setattr__(self, "discrete_state_input", self.pi05)
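Because the dataclass is frozen, `__post_init__` fills the `None` defaults via `object.__setattr__`, so the `pi05` flag silently changes the token budget. The resulting behavior follows directly from the code above:

assert Pi0Config().max_token_len == 48
assert Pi0Config(pi05=True).max_token_len == 200
assert Pi0Config(pi05=True).discrete_state_input is True
# Explicit values always win over the pi05-dependent defaults.
assert Pi0Config(pi05=True, max_token_len=300).max_token_len == 300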
+
+    @property
+    @override
+    def model_type(self) -> _model.ModelType:
+        if self.pi05:
+            return _model.ModelType.PI05
+        return _model.ModelType.PI0
+
+    @override
+    def create(self, rng: at.KeyArrayLike) -> "Pi0":
+        from openpi.models.pi0 import Pi0
+
+        return Pi0(self, rngs=nnx.Rngs(rng))
+
+    @override
+    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
+        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
+        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
+
+        with at.disable_typechecking():
+            observation_spec = _model.Observation(
+                images={
+                    "base_0_rgb": image_spec,
+                    "left_wrist_0_rgb": image_spec,
+                    "right_wrist_0_rgb": image_spec,
+                },
+                image_masks={
+                    "base_0_rgb": image_mask_spec,
+                    "left_wrist_0_rgb": image_mask_spec,
+                    "right_wrist_0_rgb": image_mask_spec,
+                },
+                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
+                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
+                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
+            )
+        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
+
+        return observation_spec, action_spec
+
+    def get_freeze_filter(self) -> nnx.filterlib.Filter:
+        """Returns the freeze filter based on the model config."""
+        filters = []
+        has_lora = False
+        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
+        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
+        if "lora" in self.paligemma_variant:
+            filters.append(
+                gemma_params_filter,
+            )
+            if "lora" not in self.action_expert_variant:
+                # If we only freeze the Gemma params, exclude the action expert params.
+                filters.append(
+                    nnx.Not(action_expert_params_filter),
+                )
+            has_lora = True
+        elif "lora" in self.action_expert_variant:
+            filters.append(
+                action_expert_params_filter,
+            )
+            has_lora = True
+
+        if has_lora:
+            # If any LoRA is used, exclude all LoRA params from the freeze filter so they remain trainable.
+            filters.append(
+                nnx.Not(nnx_utils.PathRegex(".*lora.*")),
+            )
+        if not filters:
+            return nnx.Nothing
+        return nnx.All(*filters)
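A sketch of how such a filter can be applied, using the same `nnx.state` filtering idiom as the LoRA test above (a minimal illustration assuming a constructed model; the actual training loop is not part of this diff):

import jax
from flax import nnx

config = Pi0Config(paligemma_variant="gemma_2b_lora")
model = config.create(jax.random.key(0))

# Partition the model state into frozen and trainable subsets.
freeze_filter = config.get_freeze_filter()
frozen = nnx.state(model, freeze_filter)
trainable = nnx.state(model, nnx.Not(freeze_filter))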
capvector-pi05/src/openpi/models/pi0_fast.py ADDED
@@ -0,0 +1,313 @@
+import dataclasses
+import logging
+from typing import Any
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+import openpi.models.gemma_fast as _gemma
+import openpi.models.siglip as _siglip
+from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+
+logger = logging.getLogger("openpi")
+
+PALIGEMMA_EOS_TOKEN = 1
+
+
+def make_attn_mask(input_mask, mask_ar):
+    """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens whose cumulative mask_ar is smaller
+    than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+      [[1 1 1 1 1 1]]: pure causal attention.
+
+      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+          themselves and the last 3 tokens have a causal attention. The first
+          entry could also be a 1 without changing behaviour.
+
+      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+          block can attend all previous blocks and all tokens on the same block.
+
+    Args:
+      input_mask: bool[B, N] true if it is part of the input, false if padding.
+      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+        it and false where it shares the same attention mask as the previous token.
+    """
+    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+    cumsum = jnp.cumsum(mask_ar, axis=1)
+    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+    return jnp.logical_and(attn_mask, valid_mask)
+
+
+@jax.vmap
+def left_to_right_align(x, input_mask, attn_mask):
+    """Converts inputs from left-aligned to right-aligned."""
+    # Due to vmap, this operates on a single example (not at the batch level).
+    assert x.ndim == 2
+    assert input_mask.ndim == 1
+    assert attn_mask.ndim == 2
+    assert x.shape[0] == input_mask.shape[0]
+    assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
+    seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
+    x = jnp.roll(x, -seqlen, axis=0)
+    input_mask = jnp.roll(input_mask, -seqlen, axis=0)
+    attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))
+    return x, input_mask, attn_mask
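A small illustration of the rolling trick on one batch element: with 2 valid tokens in a length-4 buffer, `seqlen = 2`, and rolling by `-seqlen` wraps the valid tokens around to the right edge (values illustrative):

import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32).reshape(1, 4, 1)  # (b=1, s=4, emb=1)
input_mask = jnp.array([[True, True, False, False]])   # 2 valid tokens, left-aligned
attn_mask = jnp.ones((1, 4, 4), dtype=bool)
x_r, mask_r, _ = left_to_right_align(x, input_mask, attn_mask)
assert (mask_r == jnp.array([[False, False, True, True]])).all()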
+
+
+def put_along_last_axis(arr, indices, values):
+    """Like np.put_along_axis(..., axis=-1), since jax is missing it."""
+    assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)
+    onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
+    put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot)
+    put_values = jnp.einsum("...i,...in->...n", values, onehot)
+    return jnp.where(put_mask, put_values, arr)
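The scatter is expressed with one-hot einsums rather than indexed assignment, which keeps it JIT-friendly. A quick check of the behavior (illustrative values):

import jax.numpy as jnp

arr = jnp.zeros((1, 4))
out = put_along_last_axis(arr, jnp.array([[2]]), jnp.array([[7.0]]))
assert (out == jnp.array([[0.0, 0.0, 7.0, 0.0]])).all()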
+
+
+@dataclasses.dataclass(frozen=True)
+class Pi0FASTConfig(_model.BaseModelConfig):
+    dtype: str = "bfloat16"
+    paligemma_variant: _gemma.Variant = "gemma_2b"
+
+    # Set the model specific defaults.
+    action_dim: int = 32
+    action_horizon: int = 32
+    max_token_len: int = 250
+
+    # Tokenizer for the FAST model.
+    fast_model_tokenizer: Any | None = None
+    # Keyword arguments for the FAST model tokenizer.
+    fast_model_tokenizer_kwargs: dict[str, Any] | None = None
+
+    @property
+    @override
+    def model_type(self) -> _model.ModelType:
+        return _model.ModelType.PI0_FAST
+
+    @override
+    def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
+        return Pi0FAST(self, rngs=nnx.Rngs(rng))
+
+    @override
+    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
+        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
+        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
+
+        with at.disable_typechecking():
+            observation_spec = _model.Observation(
+                images={
+                    "base_0_rgb": image_spec,
+                    "base_1_rgb": image_spec,
+                    "wrist_0_rgb": image_spec,
+                },
+                image_masks={
+                    "base_0_rgb": image_mask_spec,
+                    "base_1_rgb": image_mask_spec,
+                    "wrist_0_rgb": image_mask_spec,
+                },
+                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
+                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
+                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
+                token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
+                token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),
+            )
+        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
+
+        return observation_spec, action_spec
+
+    def get_freeze_filter(self) -> nnx.filterlib.Filter:
+        """Returns the freeze filter based on the model config."""
+        if "lora" in self.paligemma_variant:
+            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
+        return nnx.Nothing
+
+
+class Pi0FAST(_model.BaseModel):
+    def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
+        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
+        paligemma_config = _gemma.get_config(config.paligemma_variant)
+        # TODO: rewrite gemma in NNX. For now, use bridge.
+        llm = nnx_bridge.ToNNX(
+            _gemma.Module(
+                **paligemma_config,
+                embed_dtype=config.dtype,
+                cache_dtype=config.dtype,
+            )
+        )
+        llm.lazy_init(rngs=rngs, method="init")
+        img = nnx_bridge.ToNNX(
+            _siglip.Module(
+                num_classes=paligemma_config.width,
+                variant="So400m/14",
+                pool_type="none",
+                scan=True,
+                dtype_mm=config.dtype,
+            )
+        )
+        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
+        self.PaliGemma = nnx.Dict(llm=llm, img=img)
+
+    @at.typecheck
+    def embed_inputs(
+        self, obs: _model.Observation
+    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
+        input_mask = []
+        ar_mask = []
+        token_embeddings = []
+        # embed images
+        for name in obs.images:
+            image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)
+
+            token_embeddings.append(image_token_embeddings)
+            input_mask.append(
+                einops.repeat(
+                    obs.image_masks[name],
+                    "b -> b s",
+                    s=image_token_embeddings.shape[1],
+                )
+            )
+            # image tokens attend to each other --> AR mask = 0
+            ar_mask.append(0 * input_mask[-1])
+
+        # add tokenized inputs
+        assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
+        assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
+        assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
+        tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)
+        token_embeddings.append(tokenized_inputs_embeddings)
+        input_mask.append(obs.tokenized_prompt_mask)
+        ar_mask.append(obs.token_ar_mask)
+
+        # return embeddings, input mask, and ar mask
+        return (
+            jnp.concatenate(token_embeddings, axis=1),
+            jnp.concatenate(input_mask, axis=1),
+            jnp.concatenate(ar_mask, axis=1),
+        )
+
+    @override
+    def compute_loss(
+        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
+    ) -> at.Float[at.Array, "*b ah"]:
+        observation = _model.preprocess_observation(
+            rng, observation, train=train, image_keys=list(observation.images.keys())
+        )
+
+        # Compute inputs: one big forward pass of prefix + suffix at once.
+        input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
+        attn_mask = make_attn_mask(input_mask, ar_mask)
+
+        # Compute one-hot targets: we predict the *next* token, so shift the input tokens by one.
+        targets = jax.nn.one_hot(
+            observation.tokenized_prompt[:, 1:],
+            self.PaliGemma.llm.module.vocab_size,
+        )
+
+        # Each input predicts the *next* token, so we don't input the last token.
+        pre_logits, _, _ = self.PaliGemma.llm(
+            embedded_prefix=input_token_embeddings[:, :-1],
+            mask=attn_mask[:, :-1, :-1],
+            return_prelogits=True,
+        )
+
+        # Only decode logits for the target tokens to save memory
+        # (the decoding matmul is large because it is a seq_len x vocab_size dense layer).
+        logits, _ = self.PaliGemma.llm(
+            pre_logits=pre_logits[:, -targets.shape[1] :],
+        )
+        logp = jax.nn.log_softmax(logits, axis=-1)
+
+        # Compute the CE loss on token targets.
+        assert observation.token_loss_mask is not None, "Token loss mask is required"
+        loss_mask = observation.token_loss_mask[:, 1:]
+        token_pplx = jnp.sum(targets * logp, axis=-1)
+        return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)
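The shift-by-one bookkeeping above is the standard next-token setup: inputs drop the last token, targets drop the first, and the loss mask is shifted along with the targets. A model-free sketch of the alignment (hypothetical 5-token sequence):

import jax.numpy as jnp

tokens = jnp.array([[11, 12, 13, 14, 15]])
loss_mask = jnp.array([[False, False, True, True, True]])  # only score the last tokens

inputs = tokens[:, :-1]         # [11, 12, 13, 14] -- each position predicts its successor
targets = tokens[:, 1:]         # [12, 13, 14, 15]
target_mask = loss_mask[:, 1:]  # mask aligned with the targets
assert inputs.shape == targets.shape == target_mask.shape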
+
+    @override
+    def sample_actions(
+        self,
+        rng: at.KeyArrayLike,
+        observation: _model.Observation,
+        *,
+        max_decoding_steps: int | at.Int[at.Array, ""] = 256,
+        temperature: float = 0.0,
+    ) -> _model.Actions:
+        # TODO: this is a hack to get the image keys.
+        observation = _model.preprocess_observation(
+            None, observation, train=False, image_keys=list(observation.images.keys())
+        )
+
+        # embed inputs
+        prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
+        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
+
+        # left to right align all input token sequences
+        prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
+            prefix_token_embeddings, prefix_mask, prefix_attn_mask
+        )
+        prefill_size = prefix_token_embeddings.shape[1]
+        prefill_len = jnp.sum(prefix_mask, axis=-1)
+        prefix_start = prefill_size - prefill_len
+
+        # first fill KV cache with a forward pass of the prefix
+        # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
+        prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
+        prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
+        prefix_logits, kv_cache, _ = self.PaliGemma.llm(
+            embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
+        )
+
+        # prepare decoding -- the final logit decodes the first output token
+        last_logit = prefix_logits[:, -1:]
+        output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
+
+        def step(carry):
+            rng, last_logit, output_tokens, cache, _, step = carry
+
+            # Sample a token from the last logit, splitting the RNG for this step.
+            rng, rng_step = jax.random.split(rng)
+            token = jax.lax.cond(
+                temperature > 0.0,
+                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
+                lambda _: jnp.argmax(last_logit, axis=-1),
+                operand=None,
+            )
+            output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
+
+            # Check for early stopping --> stop if all batch elements have an EOS token.
+            has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
+            all_eos = jnp.all(has_eos)
+
+            # Decode one step.
+            token_embedding = self.PaliGemma.llm(token, embed_only=True)
+            positions = prefill_len[:, None] + step + 1
+            mask = jnp.logical_and(
+                jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
+                jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
+                < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
+            )
+            last_logit, kv_cache, _ = self.PaliGemma.llm(
+                embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
+            )
+
+            return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1
+
+        def cond(carry):
+            _, _, _, _, all_eos, step = carry
+            return (~all_eos) & (step < max_decoding_steps)
+
+        # Use lax.while_loop so we can jit the full decoding loop.
+        _, _, output_tokens, _, _, _ = jax.lax.while_loop(
+            cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
+        )
+        return output_tokens