diff --git a/capvector-pi05/examples/libero/Dockerfile b/capvector-pi05/examples/libero/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..3e1ed413f0441f3ab556974910fc6a7e138dd1e6
--- /dev/null
+++ b/capvector-pi05/examples/libero/Dockerfile
@@ -0,0 +1,59 @@
+# Dockerfile for the LIBERO benchmark.
+
+# Build the container:
+# docker build . -t libero -f examples/libero/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
+
+FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+RUN apt-get update && \
+    apt-get install -y \
+    make \
+    g++ \
+    clang \
+    libosmesa6-dev \
+    libgl1-mesa-glx \
+    libglew-dev \
+    libglfw3-dev \
+    libgles2-mesa-dev \
+    libglib2.0-0 \
+    libsm6 \
+    libxrender1 \
+    libxext6
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/libero/requirements.txt /tmp/requirements.txt
+COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
+
+# Create a default config file to avoid an input prompt from LIBERO's init script.
+# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
+ENV LIBERO_CONFIG_PATH=/tmp/libero
+RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
+benchmark_root: /app/third_party/libero/libero/libero
+bddl_files: /app/third_party/libero/libero/libero/bddl_files
+init_states: /app/third_party/libero/libero/libero/init_files
+datasets: /app/third_party/libero/libero/datasets
+assets: /app/third_party/libero/libero/libero/assets
+EOF
+
+CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
diff --git a/capvector-pi05/examples/libero/README.md b/capvector-pi05/examples/libero/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2e16f67368d09ae80e16ab67515907fcb034c1cc
--- /dev/null
+++ b/capvector-pi05/examples/libero/README.md
@@ -0,0 +1,71 @@
+# LIBERO Benchmark
+
+This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
+
+This example requires git submodules to be initialized.
+
+Note: When updating requirements.txt in this directory, you must add the flag `--extra-index-url https://download.pytorch.org/whl/cu113` to the `uv pip compile` command, as shown in the example below.
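+For reference, the full command would then look like the following sketch (reconstructed from the autogenerated header of the checked-in requirements.txt, with the extra index flag added):
+
+```bash
+uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt \
+    --python-version 3.8 --index-strategy=unsafe-best-match \
+    --extra-index-url https://download.pytorch.org/whl/cu113
+```
+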
+To initialize the submodules, run:
+
+```bash
+git submodule update --init --recursive
+```
+
+## With Docker (recommended)
+
+```bash
+# Grant access to the X11 server:
+sudo xhost +local:docker
+
+# To run with the default checkpoint and task suite:
+SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+
+# To run with glx for Mujoco instead (use this if you have egl errors):
+MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+```
+
+You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
+For example:
+
+```bash
+# To load a custom checkpoint (located in the top-level openpi/ directory):
+export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
+
+# To run the libero_10 task suite:
+export CLIENT_ARGS="--args.task-suite-name libero_10"
+```
+
+## Without Docker (not recommended)
+
+Terminal window 1:
+
+```bash
+# Create virtual environment
+uv venv --python 3.8 examples/libero/.venv
+source examples/libero/.venv/bin/activate
+uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+uv pip install -e packages/openpi-client
+uv pip install -e third_party/libero
+export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
+
+# Run the simulation
+python examples/libero/main.py
+
+# To run with glx for Mujoco instead (use this if you have egl errors):
+MUJOCO_GL=glx python examples/libero/main.py
+```
+
+Terminal window 2:
+
+```bash
+# Run the server
+uv run scripts/serve_policy.py --env LIBERO
+```
+
+## Results
+
+If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
+checkpoint was trained in openpi with the `pi05_libero` config.
+
+| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
+|-------|---------------|---------------|-------------|-----------|---------|
+| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |
diff --git a/capvector-pi05/examples/libero/main.py b/capvector-pi05/examples/libero/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a1ab94db9aac8288708455d3404e03f08e5383f
--- /dev/null
+++ b/capvector-pi05/examples/libero/main.py
@@ -0,0 +1,219 @@
+import collections
+import dataclasses
+import logging
+import math
+import pathlib
+
+import imageio
+from libero.libero import benchmark
+from libero.libero import get_libero_path
+from libero.libero.envs import OffScreenRenderEnv
+import numpy as np
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy as _websocket_client_policy
+import tqdm
+import tyro
+
+LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
+LIBERO_ENV_RESOLUTION = 256  # resolution used to render training data
+
+
+@dataclasses.dataclass
+class Args:
+    #################################################################################################################
+    # Model server parameters
+    #################################################################################################################
+    host: str = "0.0.0.0"
+    port: int = 8000
+    resize_size: int = 224
+    replan_steps: int = 5
+
+    #################################################################################################################
+    # LIBERO environment-specific parameters
+    #################################################################################################################
+    task_suite_name: str = (
+        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
+    )
+    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
+    num_trials_per_task: int = 50  # Number of rollouts per task
+
+    #################################################################################################################
+    # Utils
+    #################################################################################################################
+    video_out_path: str = "data/libero/videos"  # Path to save videos
+
+    seed: int = 7  # Random seed (for reproducibility)
+
+
+def eval_libero(args: Args) -> None:
+    # Set random seed
+    np.random.seed(args.seed)
+
+    # Initialize LIBERO task suite
+    benchmark_dict = benchmark.get_benchmark_dict()
+    task_suite = benchmark_dict[args.task_suite_name]()
+    num_tasks_in_suite = task_suite.n_tasks
+    logging.info(f"Task suite: {args.task_suite_name}")
+
+    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
+
+    if args.task_suite_name == "libero_spatial":
+        max_steps = 220  # longest training demo has 193 steps
+    elif args.task_suite_name == "libero_object":
+        max_steps = 280  # longest training demo has 254 steps
+    elif args.task_suite_name == "libero_goal":
+        max_steps = 300  # longest training demo has 270 steps
+    elif args.task_suite_name == "libero_10":
+        max_steps = 520  # longest training demo has 505 steps
+    elif args.task_suite_name == "libero_90":
+        max_steps = 400  # longest training demo has 373 steps
+    else:
+        raise ValueError(f"Unknown task suite: {args.task_suite_name}")
+
+    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
+
+    # Start evaluation
+    total_episodes, total_successes = 0, 0
+    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
+        # Get task
+        task = 
task_suite.get_task(task_id) + + # Get default LIBERO initial states + initial_states = task_suite.get_task_init_states(task_id) + + # Initialize LIBERO environment and task description + env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed) + + # Start episodes + task_episodes, task_successes = 0, 0 + for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)): + logging.info(f"\nTask: {task_description}") + + # Reset environment + env.reset() + action_plan = collections.deque() + + # Set initial states + obs = env.set_init_state(initial_states[episode_idx]) + + # Setup + t = 0 + replay_images = [] + + logging.info(f"Starting episode {task_episodes+1}...") + while t < max_steps + args.num_steps_wait: + try: + # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects + # and we need to wait for them to fall + if t < args.num_steps_wait: + obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION) + t += 1 + continue + + # Get preprocessed image + # IMPORTANT: rotate 180 degrees to match train preprocessing + img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) + wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) + img = image_tools.convert_to_uint8( + image_tools.resize_with_pad(img, args.resize_size, args.resize_size) + ) + wrist_img = image_tools.convert_to_uint8( + image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size) + ) + + # Save preprocessed image for replay video + replay_images.append(img) + + if not action_plan: + # Finished executing previous action chunk -- compute new chunk + # Prepare observations dict + element = { + "observation/image": img, + "observation/wrist_image": wrist_img, + "observation/state": np.concatenate( + ( + obs["robot0_eef_pos"], + _quat2axisangle(obs["robot0_eef_quat"]), + obs["robot0_gripper_qpos"], + ) + ), + "prompt": str(task_description), + } + + # Query model to get action + action_chunk = client.infer(element)["actions"] + assert ( + len(action_chunk) >= args.replan_steps + ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps." 
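+                        # Receding-horizon control: execute only the first `replan_steps`
+                        # actions of the chunk, then request a fresh chunk. This keeps the
+                        # policy reactive to new observations while amortizing inference cost.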
+ action_plan.extend(action_chunk[: args.replan_steps]) + + action = action_plan.popleft() + + # Execute action in environment + obs, reward, done, info = env.step(action.tolist()) + if done: + task_successes += 1 + total_successes += 1 + break + t += 1 + + except Exception as e: + logging.error(f"Caught exception: {e}") + break + + task_episodes += 1 + total_episodes += 1 + + # Save a replay video of the episode + suffix = "success" if done else "failure" + task_segment = task_description.replace(" ", "_") + imageio.mimwrite( + pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4", + [np.asarray(x) for x in replay_images], + fps=10, + ) + + # Log current results + logging.info(f"Success: {done}") + logging.info(f"# episodes completed so far: {total_episodes}") + logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)") + + # Log final results + logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}") + logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}") + + logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}") + logging.info(f"Total episodes: {total_episodes}") + + +def _get_libero_env(task, resolution, seed): + """Initializes and returns the LIBERO environment, along with the task description.""" + task_description = task.language + task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file + env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} + env = OffScreenRenderEnv(**env_args) + env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state + return env, task_description + + +def _quat2axisangle(quat): + """ + Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 + """ + # clip quaternion + if quat[3] > 1.0: + quat[3] = 1.0 + elif quat[3] < -1.0: + quat[3] = -1.0 + + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(den, 0.0): + # This is (close to) a zero degree rotation, immediately return + return np.zeros(3) + + return (quat[:3] * 2.0 * math.acos(quat[3])) / den + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + tyro.cli(eval_libero) diff --git a/capvector-pi05/examples/libero/requirements.in b/capvector-pi05/examples/libero/requirements.in new file mode 100644 index 0000000000000000000000000000000000000000..d9fd2275d739216c453e67fbe3c060ccec56cca4 --- /dev/null +++ b/capvector-pi05/examples/libero/requirements.in @@ -0,0 +1,11 @@ +imageio[ffmpeg] +numpy==1.22.4 +tqdm +tyro +PyYaml +opencv-python==4.6.0.66 +torch==1.11.0+cu113 +torchvision==0.12.0+cu113 +torchaudio==0.11.0+cu113 +robosuite==1.4.1 +matplotlib==3.5.3 diff --git a/capvector-pi05/examples/libero/requirements.txt b/capvector-pi05/examples/libero/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9123401789ca3740077de1a56842db145a697781 --- /dev/null +++ b/capvector-pi05/examples/libero/requirements.txt @@ -0,0 +1,136 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match +absl-py==2.1.0 + # via mujoco +certifi==2024.12.14 + # via requests +charset-normalizer==3.4.0 + # via requests +cycler==0.12.1 
+ # via matplotlib +docstring-parser==0.16 + # via tyro +etils==1.3.0 + # via mujoco +eval-type-backport==0.2.0 + # via tyro +evdev==1.7.1 + # via pynput +fonttools==4.55.3 + # via matplotlib +glfw==1.12.0 + # via mujoco +idna==3.10 + # via requests +imageio==2.35.1 + # via -r examples/libero/requirements.in +imageio-ffmpeg==0.5.1 + # via imageio +importlib-metadata==8.5.0 + # via typeguard +importlib-resources==6.4.5 + # via etils +kiwisolver==1.4.7 + # via matplotlib +llvmlite==0.36.0 + # via numba +markdown-it-py==3.0.0 + # via rich +matplotlib==3.5.3 + # via -r examples/libero/requirements.in +mdurl==0.1.2 + # via markdown-it-py +mujoco==3.2.3 + # via robosuite +numba==0.53.1 + # via robosuite +numpy==1.22.4 + # via + # -r examples/libero/requirements.in + # imageio + # matplotlib + # mujoco + # numba + # opencv-python + # robosuite + # scipy + # torchvision +opencv-python==4.6.0.66 + # via + # -r examples/libero/requirements.in + # robosuite +packaging==24.2 + # via matplotlib +pillow==10.4.0 + # via + # imageio + # matplotlib + # robosuite + # torchvision +psutil==6.1.0 + # via imageio +pygments==2.18.0 + # via rich +pynput==1.7.7 + # via robosuite +pyopengl==3.1.7 + # via mujoco +pyparsing==3.1.4 + # via matplotlib +python-dateutil==2.9.0.post0 + # via matplotlib +python-xlib==0.33 + # via pynput +pyyaml==6.0.2 + # via -r examples/libero/requirements.in +requests==2.32.3 + # via torchvision +rich==13.9.4 + # via tyro +robosuite==1.4.1 + # via -r examples/libero/requirements.in +scipy==1.10.1 + # via robosuite +setuptools==75.3.0 + # via + # imageio-ffmpeg + # numba +shtab==1.7.1 + # via tyro +six==1.17.0 + # via + # pynput + # python-dateutil + # python-xlib +termcolor==2.4.0 + # via robosuite +torch==1.11.0+cu113 + # via + # -r examples/libero/requirements.in + # torchaudio + # torchvision +torchaudio==0.11.0+cu113 + # via -r examples/libero/requirements.in +torchvision==0.12.0+cu113 + # via -r examples/libero/requirements.in +tqdm==4.67.1 + # via -r examples/libero/requirements.in +typeguard==4.4.0 + # via tyro +typing-extensions==4.12.2 + # via + # etils + # rich + # torch + # torchvision + # typeguard + # tyro +tyro==0.9.2 + # via -r examples/libero/requirements.in +urllib3==2.2.3 + # via requests +zipp==3.20.2 + # via + # etils + # importlib-metadata + # importlib-resources diff --git a/capvector-pi05/examples/simple_client/Dockerfile b/capvector-pi05/examples/simple_client/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..095712073f98d5dd9b639d6d365698b506d15827 --- /dev/null +++ b/capvector-pi05/examples/simple_client/Dockerfile @@ -0,0 +1,32 @@ +# Dockerfile for the simple client. + +# Build the container: +# docker build . -t simple_client -f examples/simple_client/Dockerfile + +# Run the container: +# docker run --rm -it --network=host -v .:/app simple_client /bin/bash + +FROM python:3.7-slim +COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ + +WORKDIR /app + +# Copy from the cache instead of linking since it's a mounted volume +ENV UV_LINK_MODE=copy + +# Write the virtual environment outside of the project directory so it doesn't +# leak out of the container when we mount the application code. +ENV UV_PROJECT_ENVIRONMENT=/.venv + +# Copy the requirements files so we can install dependencies. +# The rest of the project is mounted as a volume, so we don't need to rebuild on changes. +# This strategy is best for development-style usage. 
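+# Note: `uv pip sync` (used below) makes the environment match the given requirements
+# files exactly, removing any packages that are not listed.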
+COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
+
+CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
diff --git a/capvector-pi05/examples/simple_client/README.md b/capvector-pi05/examples/simple_client/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea2fe5050c25cbfccea3f1c5c99c8b6ef0961caa
--- /dev/null
+++ b/capvector-pi05/examples/simple_client/README.md
@@ -0,0 +1,30 @@
+# Simple Client
+
+A minimal client that sends observations to the server and prints the inference rate.
+
+You can specify which runtime environment to use with the `--env` flag. You can see the available options by running:
+
+```bash
+uv run examples/simple_client/main.py --help
+```
+
+## With Docker
+
+```bash
+export SERVER_ARGS="--env ALOHA_SIM"
+docker compose -f examples/simple_client/compose.yml up --build
+```
+
+## Without Docker
+
+Terminal window 1:
+
+```bash
+uv run examples/simple_client/main.py --env DROID
+```
+
+Terminal window 2:
+
+```bash
+uv run scripts/serve_policy.py --env DROID
+```
diff --git a/capvector-pi05/examples/simple_client/compose.yml b/capvector-pi05/examples/simple_client/compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..109821bb445ff61c0ccb695fb5c43a1eb4220005
--- /dev/null
+++ b/capvector-pi05/examples/simple_client/compose.yml
@@ -0,0 +1,42 @@
+# Run with:
+# docker compose -f examples/simple_client/compose.yml up --build
+services:
+  runtime:
+    image: simple_client
+    depends_on:
+      - openpi_server
+    build:
+      context: ../..
+      dockerfile: examples/simple_client/Dockerfile
+    init: true
+    tty: true
+    network_mode: host
+    volumes:
+      - $PWD:/app
+    environment:
+      - SERVER_ARGS
+
+  openpi_server:
+    image: openpi_server
+    build:
+      context: ../..
+      dockerfile: scripts/docker/serve_policy.Dockerfile
+    init: true
+    tty: true
+    network_mode: host
+    volumes:
+      - $PWD:/app
+      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+    environment:
+      - SERVER_ARGS
+      - OPENPI_DATA_HOME=/openpi_assets
+      - IS_DOCKER=true
+
+    # Comment out this block if not running on a machine with GPUs.
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
diff --git a/capvector-pi05/examples/simple_client/main.py b/capvector-pi05/examples/simple_client/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..3907706b164e55d04e948cae6620d940c69ca7fe
--- /dev/null
+++ b/capvector-pi05/examples/simple_client/main.py
@@ -0,0 +1,187 @@
+import dataclasses
+import enum
+import logging
+import pathlib
+import time
+
+import numpy as np
+from openpi_client import websocket_client_policy as _websocket_client_policy
+import polars as pl
+import rich.console
+import rich.table
+import tqdm
+import tyro
+
+logger = logging.getLogger(__name__)
+
+
+class EnvMode(enum.Enum):
+    """Supported environments."""
+
+    ALOHA = "aloha"
+    ALOHA_SIM = "aloha_sim"
+    DROID = "droid"
+    LIBERO = "libero"
+
+
+@dataclasses.dataclass
+class Args:
+    """Command line arguments."""
+
+    # Host to connect to the server.
+    host: str = "0.0.0.0"
+    # Port to connect to the server. If None, the server will use the default port.
+ port: int | None = 8000 + # API key to use for the server. + api_key: str | None = None + # Number of steps to run the policy for. + num_steps: int = 20 + # Path to save the timings to a parquet file. (e.g., timing.parquet) + timing_file: pathlib.Path | None = None + # Environment to run the policy in. + env: EnvMode = EnvMode.ALOHA_SIM + + +class TimingRecorder: + """Records timing measurements for different keys.""" + + def __init__(self) -> None: + self._timings: dict[str, list[float]] = {} + + def record(self, key: str, time_ms: float) -> None: + """Record a timing measurement for the given key.""" + if key not in self._timings: + self._timings[key] = [] + self._timings[key].append(time_ms) + + def get_stats(self, key: str) -> dict[str, float]: + """Get statistics for the given key.""" + times = self._timings[key] + return { + "mean": float(np.mean(times)), + "std": float(np.std(times)), + "p25": float(np.quantile(times, 0.25)), + "p50": float(np.quantile(times, 0.50)), + "p75": float(np.quantile(times, 0.75)), + "p90": float(np.quantile(times, 0.90)), + "p95": float(np.quantile(times, 0.95)), + "p99": float(np.quantile(times, 0.99)), + } + + def print_all_stats(self) -> None: + """Print statistics for all keys in a concise format.""" + + table = rich.table.Table( + title="[bold blue]Timing Statistics[/bold blue]", + show_header=True, + header_style="bold white", + border_style="blue", + title_justify="center", + ) + + # Add metric column with custom styling + table.add_column("Metric", style="cyan", justify="left", no_wrap=True) + + # Add statistical columns with consistent styling + stat_columns = [ + ("Mean", "yellow", "mean"), + ("Std", "yellow", "std"), + ("P25", "magenta", "p25"), + ("P50", "magenta", "p50"), + ("P75", "magenta", "p75"), + ("P90", "magenta", "p90"), + ("P95", "magenta", "p95"), + ("P99", "magenta", "p99"), + ] + + for name, style, _ in stat_columns: + table.add_column(name, justify="right", style=style, no_wrap=True) + + # Add rows for each metric with formatted values + for key in sorted(self._timings.keys()): + stats = self.get_stats(key) + values = [f"{stats[key]:.1f}" for _, _, key in stat_columns] + table.add_row(key, *values) + + # Print with custom console settings + console = rich.console.Console(width=None, highlight=True) + console.print(table) + + def write_parquet(self, path: pathlib.Path) -> None: + """Save the timings to a parquet file.""" + logger.info(f"Writing timings to {path}") + frame = pl.DataFrame(self._timings) + path.parent.mkdir(parents=True, exist_ok=True) + frame.write_parquet(path) + + +def main(args: Args) -> None: + obs_fn = { + EnvMode.ALOHA: _random_observation_aloha, + EnvMode.ALOHA_SIM: _random_observation_aloha, + EnvMode.DROID: _random_observation_droid, + EnvMode.LIBERO: _random_observation_libero, + }[args.env] + + policy = _websocket_client_policy.WebsocketClientPolicy( + host=args.host, + port=args.port, + api_key=args.api_key, + ) + logger.info(f"Server metadata: {policy.get_server_metadata()}") + + # Send a few observations to make sure the model is loaded. 
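+    # The first inference call is typically much slower (model loading and, for JAX
+    # policies, JIT compilation on the server), so these warmup calls keep that cost
+    # out of the timing statistics collected below.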
+ for _ in range(2): + policy.infer(obs_fn()) + + timing_recorder = TimingRecorder() + + for _ in tqdm.trange(args.num_steps, desc="Running policy"): + inference_start = time.time() + action = policy.infer(obs_fn()) + timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start)) + for key, value in action.get("server_timing", {}).items(): + timing_recorder.record(f"server_{key}", value) + for key, value in action.get("policy_timing", {}).items(): + timing_recorder.record(f"policy_{key}", value) + + timing_recorder.print_all_stats() + + if args.timing_file is not None: + timing_recorder.write_parquet(args.timing_file) + + +def _random_observation_aloha() -> dict: + return { + "state": np.ones((14,)), + "images": { + "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + }, + "prompt": "do something", + } + + +def _random_observation_droid() -> dict: + return { + "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "observation/joint_position": np.random.rand(7), + "observation/gripper_position": np.random.rand(1), + "prompt": "do something", + } + + +def _random_observation_libero() -> dict: + return { + "observation/state": np.random.rand(8), + "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "prompt": "do something", + } + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main(tyro.cli(Args)) diff --git a/capvector-pi05/examples/simple_client/requirements.in b/capvector-pi05/examples/simple_client/requirements.in new file mode 100644 index 0000000000000000000000000000000000000000..17ef4aef112d274624eba0503e00cc4aec44f7a6 --- /dev/null +++ b/capvector-pi05/examples/simple_client/requirements.in @@ -0,0 +1,5 @@ +numpy>=1.22.4,<2.0.0 +rich +tqdm +tyro +polars \ No newline at end of file diff --git a/capvector-pi05/examples/simple_client/requirements.txt b/capvector-pi05/examples/simple_client/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..416d9cd72e32f4c35c93dd0ac8e5d2144dc52513 --- /dev/null +++ b/capvector-pi05/examples/simple_client/requirements.txt @@ -0,0 +1,30 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9 +docstring-parser==0.16 + # via tyro +markdown-it-py==3.0.0 + # via rich +mdurl==0.1.2 + # via markdown-it-py +numpy==1.26.4 + # via -r examples/simple_client/requirements.in +polars==1.30.0 + # via -r examples/simple_client/requirements.in +pygments==2.19.1 + # via rich +rich==14.0.0 + # via + # -r examples/simple_client/requirements.in + # tyro +shtab==1.7.2 + # via tyro +tqdm==4.67.1 + # via -r examples/simple_client/requirements.in +typeguard==4.4.2 + # via tyro +typing-extensions==4.13.2 + # via + # typeguard + # tyro +tyro==0.9.22 + # via -r examples/simple_client/requirements.in diff --git a/capvector-pi05/examples/ur5/README.md b/capvector-pi05/examples/ur5/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..e90ca6c3bc8d1135e92ae55651639305bc0e209f
--- /dev/null
+++ b/capvector-pi05/examples/ur5/README.md
@@ -0,0 +1,142 @@
+# UR5 Example
+
+Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
+
+First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding file in `src/openpi/policies/libero_policy.py` for comments explaining each line.
+
+```python
+
+@dataclasses.dataclass(frozen=True)
+class UR5Inputs(transforms.DataTransformFn):
+
+    # The action dimension of the model. Used to pad the state to the model's input dimension.
+    action_dim: int
+
+    model_type: _model.ModelType = _model.ModelType.PI0
+
+    def __call__(self, data: dict) -> dict:
+        # First, concatenate the joints and gripper into the state vector, then pad it
+        # to the model's action dimension.
+        state = np.concatenate([data["joints"], data["gripper"]])
+        state = transforms.pad_to_dim(state, self.action_dim)
+
+        # Parse images to uint8 (H,W,C) if needed. LeRobot stores images as float32 (C,H,W),
+        # so this conversion is needed during training; it is skipped for policy inference,
+        # where images already arrive in the expected format.
+        base_image = _parse_image(data["base_rgb"])
+        wrist_image = _parse_image(data["wrist_rgb"])
+
+        # Create inputs dict.
+        inputs = {
+            "state": state,
+            "image": {
+                "base_0_rgb": base_image,
+                "left_wrist_0_rgb": wrist_image,
+                # Since there is no right wrist, replace with zeros
+                "right_wrist_0_rgb": np.zeros_like(base_image),
+            },
+            "image_mask": {
+                "base_0_rgb": np.True_,
+                "left_wrist_0_rgb": np.True_,
+                # The "slot" for the right wrist is not used, so its mask is set to False
+                # for pi0. pi0-FAST does not support masking out padding images, so the
+                # mask is set to True there.
+                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
+            },
+        }
+
+        if "actions" in data:
+            inputs["actions"] = data["actions"]
+
+        # Pass the prompt (aka language instruction) to the model.
+        if "prompt" in data:
+            inputs["prompt"] = data["prompt"]
+
+        return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class UR5Outputs(transforms.DataTransformFn):
+
+    def __call__(self, data: dict) -> dict:
+        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims
+        return {"actions": np.asarray(data["actions"][:, :7])}
+
+```
+
+Next, we will define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+
+```python
+
+@dataclasses.dataclass(frozen=True)
+class LeRobotUR5DataConfig(DataConfigFactory):
+
+    @override
+    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
+        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
+        repack_transform = _transforms.Group(
+            inputs=[
+                _transforms.RepackTransform(
+                    {
+                        "base_rgb": "image",
+                        "wrist_rgb": "wrist_image",
+                        "joints": "joints",
+                        "gripper": "gripper",
+                        "prompt": "prompt",
+                    }
+                )
+            ]
+        )
+
+        # These transforms are the ones we wrote earlier.
+        data_transforms = _transforms.Group(
+            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
+            outputs=[UR5Outputs()],
+        )
+
+        # Convert absolute actions to delta actions.
+        # By convention, we do not convert the gripper action (7th dimension).
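+        # make_bool_mask(6, -1) yields (True,)*6 + (False,): positive arguments emit that
+        # many True entries, negative arguments emit False entries. The six arm dimensions
+        # are thus converted to deltas while the gripper command stays absolute.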
+        delta_action_mask = _transforms.make_bool_mask(6, -1)
+        data_transforms = data_transforms.push(
+            inputs=[_transforms.DeltaActions(delta_action_mask)],
+            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
+        )
+
+        # Model transforms include things like tokenizing the prompt and action targets.
+        # You do not need to change anything here for your own dataset.
+        model_transforms = ModelTransformFactory()(model_config)
+
+        # We return all data transforms for training and inference. No need to change anything here.
+        return dataclasses.replace(
+            self.create_base_config(assets_dirs),
+            repack_transforms=repack_transform,
+            data_transforms=data_transforms,
+            model_transforms=model_transforms,
+        )
+
+```
+
+Finally, we define the `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
+
+```python
+TrainConfig(
+    name="pi0_ur5",
+    model=pi0.Pi0Config(),
+    data=LeRobotUR5DataConfig(
+        repo_id="your_username/ur5_dataset",
+        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
+        # Reloading normalization stats can help transfer pre-trained models to new environments.
+        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
+        assets=AssetsConfig(
+            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
+            asset_id="ur5e",
+        ),
+        base_config=DataConfig(
+            # This flag determines whether we load the prompt (i.e. the task instruction) from the
+            # ``task`` field in the LeRobot dataset. The recommended setting is True.
+            prompt_from_task=True,
+        ),
+    ),
+    # Load the pi0 base model checkpoint.
+    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
+    num_train_steps=30_000,
+)
+```
diff --git a/capvector-pi05/packages/openpi-client/pyproject.toml b/capvector-pi05/packages/openpi-client/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..123c066e6e79a6d1bbc8e385ab99422d58e16acc
--- /dev/null
+++ b/capvector-pi05/packages/openpi-client/pyproject.toml
@@ -0,0 +1,23 @@
+[project]
+name = "openpi-client"
+version = "0.1.0"
+requires-python = ">=3.7"
+dependencies = [
+    "dm-tree>=0.1.8",
+    "msgpack>=1.0.5",
+    "numpy>=1.22.4,<2.0.0",
+    "pillow>=9.0.0",
+    "typing-extensions>=4.0.0",
+    "websockets>=11.0",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.uv]
+dev-dependencies = ["pytest>=8.3.4"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py37"
diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py b/capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f5c4a7d6e309ba9807642ee936d82cbc458017e
--- /dev/null
+++ b/capvector-pi05/packages/openpi-client/src/openpi_client/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.0"
diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py b/capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9445a66815e15ee32ceb033d5a481b58053783fb
--- /dev/null
+++ b/capvector-pi05/packages/openpi-client/src/openpi_client/action_chunk_broker.py
@@ -0,0 +1,50 @@
+from typing import Dict, Optional
+
+import numpy as np
+import tree
+from typing_extensions import override
+
+from openpi_client import base_policy as _base_policy
+
+
+class ActionChunkBroker(_base_policy.BasePolicy):
+    """Wraps a policy to return action chunks one-at-a-time.
+
+    Assumes that the first dimension of all action fields is the chunk size.
+
+    A new inference call to the inner policy is only made when the current
+    chunk of actions is exhausted.
+    """
+
+    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
+        self._policy = policy
+        self._action_horizon = action_horizon
+        self._cur_step: int = 0
+
+        self._last_results: Optional[Dict[str, np.ndarray]] = None
+
+    @override
+    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
+        if self._last_results is None:
+            self._last_results = self._policy.infer(obs)
+            self._cur_step = 0
+
+        def slicer(x):
+            if isinstance(x, np.ndarray):
+                return x[self._cur_step, ...]
+            else:
+                return x
+
+        results = tree.map_structure(slicer, self._last_results)
+        self._cur_step += 1
+
+        if self._cur_step >= self._action_horizon:
+            self._last_results = None
+
+        return results
+
+    @override
+    def reset(self) -> None:
+        self._policy.reset()
+        self._last_results = None
+        self._cur_step = 0
diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py b/capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b14963fe90508b804b04c6480d82d5b3e2b5ca3
--- /dev/null
+++ b/capvector-pi05/packages/openpi-client/src/openpi_client/base_policy.py
@@ -0,0 +1,12 @@
+import abc
+from typing import Dict
+
+
+class BasePolicy(abc.ABC):
+    @abc.abstractmethod
+    def infer(self, obs: Dict) -> Dict:
+        """Infer actions from observations."""
+
+    def reset(self) -> None:
+        """Reset the policy to its initial state."""
+        pass
diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py b/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..421532e216df6345902760be6cf1e22fc3c167fa
--- /dev/null
+++ b/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools.py
@@ -0,0 +1,78 @@
+import numpy as np
+from PIL import Image
+
+
+def convert_to_uint8(img: np.ndarray) -> np.ndarray:
+    """Converts an image to uint8 if it is a float image.
+
+    This is important for reducing the size of the image when sending it over the network.
+    """
+    if np.issubdtype(img.dtype, np.floating):
+        img = (255 * img).astype(np.uint8)
+    return img
+
+
+def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR, return_mask=False):
+    """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target
+    height and width without distortion by padding with zeros.
+
+    Args:
+        images: A batch of images in [..., height, width, channel] format.
+        height: The target height of the image.
+        width: The target width of the image.
+        method: The interpolation method to use. Default is bilinear.
+        return_mask: If True, also return a boolean mask marking the valid (non-padding) pixels.
+
+    Returns:
+        The resized images in [..., height, width, channel]. If `return_mask` is True, returns a tuple of
+        (resized images, padding mask of shape [..., height, width]).
+    """
+    # If the images are already the correct size, return them as is.
+    if images.shape[-3:-1] == (height, width):
+        if return_mask:
+            img_padding_mask = np.ones((*images.shape[:-3], height, width), dtype=bool)
+            return images, img_padding_mask
+        return images
+
+    original_shape = images.shape
+
+    images = images.reshape(-1, *original_shape[-3:])
+
+    resized_results = [
+        _resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images
+    ]
+    resized_images, img_padding_mask = zip(*resized_results)
+    resized_images = np.stack(resized_images)
+    img_padding_mask = np.stack(img_padding_mask)
+
+    if return_mask:
+        return (
+            resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:]),
+            img_padding_mask.reshape(*original_shape[:-3], *img_padding_mask.shape[-2:]),
+        )
+    else:
+        return resized_images.reshape(*original_shape[:-3], *resized_images.shape[-3:])
+
+
+def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int):
+    """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
+    width without distortion by padding with zeros.
+
+    Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
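+
+    Returns:
+        A tuple of (padded image, boolean mask of shape [height, width] that is True for
+        valid pixels and False for padding).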
+ """ + cur_width, cur_height = image.size + if cur_width == width and cur_height == height: + return image # No need to resize if the image is already the correct size. + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_image = image.resize((resized_width, resized_height), resample=method) + + zero_image = Image.new(resized_image.mode, (width, height), 0) + pad_height = max(0, int((height - resized_height) / 2)) + pad_width = max(0, int((width - resized_width) / 2)) + zero_image.paste(resized_image, (pad_width, pad_height)) + assert zero_image.size == (width, height) + + img_padding_mask = np.zeros((height, width), dtype=bool) + img_padding_mask[pad_height:pad_height+resized_height, pad_width:pad_width+resized_width] = True + + return zero_image, img_padding_mask diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py b/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8a2a26c04254f6246da50a2254f3c3c0c03c96 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/image_tools_test.py @@ -0,0 +1,37 @@ +import numpy as np + +import openpi_client.image_tools as image_tools + + +def test_resize_with_pad_shapes(): + # Test case 1: Resize image with larger dimensions + images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels) + height = 20 + width = 20 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (2, height, width, 3) + assert np.all(resized_images == 0) + + # Test case 2: Resize image with smaller dimensions + images = np.zeros((3, 30, 30, 3), dtype=np.uint8) + height = 15 + width = 15 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (3, height, width, 3) + assert np.all(resized_images == 0) + + # Test case 3: Resize image with the same dimensions + images = np.zeros((1, 50, 50, 3), dtype=np.uint8) + height = 50 + width = 50 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (1, height, width, 3) + assert np.all(resized_images == 0) + + # Test case 3: Resize image with odd-numbered padding + images = np.zeros((1, 256, 320, 3), dtype=np.uint8) + height = 60 + width = 80 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (1, height, width, 3) + assert np.all(resized_images == 0) diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py b/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..70e353a9762de8ea45988354ea5d044fc03a52b4 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy.py @@ -0,0 +1,57 @@ +"""Adds NumPy array support to msgpack. 
+ +msgpack is good for (de)serializing data over a network for multiple reasons: +- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution) +- msgpack is widely used and has good cross-language support +- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed + languages like Python and JavaScript +- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster + than pickle for serializing large arrays using the below strategy + +The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is +that it falls back to pickle for object arrays. +""" + +import functools + +import msgpack +import numpy as np + + +def pack_array(obj): + if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"): + raise ValueError(f"Unsupported dtype: {obj.dtype}") + + if isinstance(obj, np.ndarray): + return { + b"__ndarray__": True, + b"data": obj.tobytes(), + b"dtype": obj.dtype.str, + b"shape": obj.shape, + } + + if isinstance(obj, np.generic): + return { + b"__npgeneric__": True, + b"data": obj.item(), + b"dtype": obj.dtype.str, + } + + return obj + + +def unpack_array(obj): + if b"__ndarray__" in obj: + return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) + + if b"__npgeneric__" in obj: + return np.dtype(obj[b"dtype"]).type(obj[b"data"]) + + return obj + + +Packer = functools.partial(msgpack.Packer, default=pack_array) +packb = functools.partial(msgpack.packb, default=pack_array) + +Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) +unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py b/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d0b027c3ba77269151c2274226a6baadb410a4 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest +import tree + +from openpi_client import msgpack_numpy + + +def _check(expected, actual): + if isinstance(expected, np.ndarray): + assert expected.shape == actual.shape + assert expected.dtype == actual.dtype + assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") + else: + assert expected == actual + + +@pytest.mark.parametrize( + "data", + [ + 1, # int + 1.0, # float + "hello", # string + np.bool_(True), # boolean scalar + np.array([1, 2, 3])[0], # int scalar + np.str_("asdf"), # string scalar + [1, 2, 3], # list + {"key": "value"}, # dict + {"key": [1, 2, 3]}, # nested dict + np.array(1.0), # 0D array + np.array([1, 2, 3], dtype=np.int32), # 1D integer array + np.array(["asdf", "qwer"]), # string array + np.array([True, False]), # boolean array + np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array + np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array + np.array([np.nan, np.inf, -np.inf]), # special float values + {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays + [np.array([1, 2]), np.array([3, 4])], # list of arrays + np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros + np.ones((2, 3), dtype=np.float64), # 2D ones with double precision + ], +) +def test_pack_unpack(data): 
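+    # Round-trip: serialize to msgpack bytes, deserialize, and verify that structure,
+    # dtypes, shapes, and values (including NaN) survive unchanged.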
+ packed = msgpack_numpy.packb(data) + unpacked = msgpack_numpy.unpackb(packed) + tree.map_structure(_check, data, unpacked) diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..d09d57ddf0e670a7630b7bff95175984c3f9212e --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agent.py @@ -0,0 +1,17 @@ +import abc + + +class Agent(abc.ABC): + """An Agent is the thing with agency, i.e. the entity that makes decisions. + + Agents receive observations about the state of the world, and return actions + to take in response. + """ + + @abc.abstractmethod + def get_action(self, observation: dict) -> dict: + """Query the agent for the next action.""" + + @abc.abstractmethod + def reset(self) -> None: + """Reset the agent to its initial state.""" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..2fff4f87f7072aa055dad04bb11c0524e385eaf4 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py @@ -0,0 +1,18 @@ +from typing_extensions import override + +from openpi_client import base_policy as _base_policy +from openpi_client.runtime import agent as _agent + + +class PolicyAgent(_agent.Agent): + """An agent that uses a policy to determine actions.""" + + def __init__(self, policy: _base_policy.BasePolicy) -> None: + self._policy = policy + + @override + def get_action(self, observation: dict) -> dict: + return self._policy.infer(observation) + + def reset(self) -> None: + self._policy.reset() diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..4b29f594f247700981fa87ff46a4500f060be052 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/environment.py @@ -0,0 +1,32 @@ +import abc + + +class Environment(abc.ABC): + """An Environment represents the robot and the environment it inhabits. + + The primary contract of environments is that they can be queried for observations + about their state, and have actions applied to them to change that state. + """ + + @abc.abstractmethod + def reset(self) -> None: + """Reset the environment to its initial state. + + This will be called once before starting each episode. + """ + + @abc.abstractmethod + def is_episode_complete(self) -> bool: + """Allow the environment to signal that the episode is complete. + + This will be called after each step. It should return `True` if the episode is + complete (either successfully or unsuccessfully), and `False` otherwise. 
+ """ + + @abc.abstractmethod + def get_observation(self) -> dict: + """Query the environment for the current state.""" + + @abc.abstractmethod + def apply_action(self, action: dict) -> None: + """Take an action in the environment.""" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..d480c2ebb01fc559a128c5e338a8be74ff8e55d3 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/runtime.py @@ -0,0 +1,92 @@ +import logging +import threading +import time + +from openpi_client.runtime import agent as _agent +from openpi_client.runtime import environment as _environment +from openpi_client.runtime import subscriber as _subscriber + + +class Runtime: + """The core module orchestrating interactions between key components of the system.""" + + def __init__( + self, + environment: _environment.Environment, + agent: _agent.Agent, + subscribers: list[_subscriber.Subscriber], + max_hz: float = 0, + num_episodes: int = 1, + max_episode_steps: int = 0, + ) -> None: + self._environment = environment + self._agent = agent + self._subscribers = subscribers + self._max_hz = max_hz + self._num_episodes = num_episodes + self._max_episode_steps = max_episode_steps + + self._in_episode = False + self._episode_steps = 0 + + def run(self) -> None: + """Runs the runtime loop continuously until stop() is called or the environment is done.""" + for _ in range(self._num_episodes): + self._run_episode() + + # Final reset, this is important for real environments to move the robot to its home position. + self._environment.reset() + + def run_in_new_thread(self) -> threading.Thread: + """Runs the runtime loop in a new thread.""" + thread = threading.Thread(target=self.run) + thread.start() + return thread + + def mark_episode_complete(self) -> None: + """Marks the end of an episode.""" + self._in_episode = False + + def _run_episode(self) -> None: + """Runs a single episode.""" + logging.info("Starting episode...") + self._environment.reset() + self._agent.reset() + for subscriber in self._subscribers: + subscriber.on_episode_start() + + self._in_episode = True + self._episode_steps = 0 + step_time = 1 / self._max_hz if self._max_hz > 0 else 0 + last_step_time = time.time() + + while self._in_episode: + self._step() + self._episode_steps += 1 + + # Sleep to maintain the desired frame rate + now = time.time() + dt = now - last_step_time + if dt < step_time: + time.sleep(step_time - dt) + last_step_time = time.time() + else: + last_step_time = now + + logging.info("Episode completed.") + for subscriber in self._subscribers: + subscriber.on_episode_end() + + def _step(self) -> None: + """A single step of the runtime loop.""" + observation = self._environment.get_observation() + action = self._agent.get_action(observation) + self._environment.apply_action(action) + + for subscriber in self._subscribers: + subscriber.on_step(observation, action) + + if self._environment.is_episode_complete() or ( + self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps + ): + self.mark_episode_complete() diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py new file mode 100644 index 0000000000000000000000000000000000000000..e11b583aa2c4c962df7ed7907f5070ef30b97ef5 --- /dev/null +++ 
b/capvector-pi05/packages/openpi-client/src/openpi_client/runtime/subscriber.py @@ -0,0 +1,20 @@ +import abc + + +class Subscriber(abc.ABC): + """Subscribes to events in the runtime. + + Subscribers can be used to save data, visualize, etc. + """ + + @abc.abstractmethod + def on_episode_start(self) -> None: + """Called when an episode starts.""" + + @abc.abstractmethod + def on_step(self, observation: dict, action: dict) -> None: + """Append a step to the episode.""" + + @abc.abstractmethod + def on_episode_end(self) -> None: + """Called when an episode ends.""" diff --git a/capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py b/capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd20760b0fb1ef93622626e17fad2f63146dbe7 --- /dev/null +++ b/capvector-pi05/packages/openpi-client/src/openpi_client/websocket_client_policy.py @@ -0,0 +1,55 @@ +import logging +import time +from typing import Dict, Optional, Tuple + +from typing_extensions import override +import websockets.sync.client + +from openpi_client import base_policy as _base_policy +from openpi_client import msgpack_numpy + + +class WebsocketClientPolicy(_base_policy.BasePolicy): + """Implements the Policy interface by communicating with a server over websocket. + + See WebsocketPolicyServer for a corresponding server implementation. + """ + + def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None: + self._uri = f"ws://{host}" + if port is not None: + self._uri += f":{port}" + self._packer = msgpack_numpy.Packer() + self._api_key = api_key + self._ws, self._server_metadata = self._wait_for_server() + + def get_server_metadata(self) -> Dict: + return self._server_metadata + + def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]: + logging.info(f"Waiting for server at {self._uri}...") + while True: + try: + headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None + conn = websockets.sync.client.connect( + self._uri, compression=None, max_size=None, additional_headers=headers + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + return conn, metadata + except ConnectionRefusedError: + logging.info("Still waiting for server...") + time.sleep(5) + + @override + def infer(self, obs: Dict) -> Dict: # noqa: UP006 + data = self._packer.pack(obs) + self._ws.send(data) + response = self._ws.recv() + if isinstance(response, str): + # we're expecting bytes; if the server sends a string, it's an error. + raise RuntimeError(f"Error in inference server:\n{response}") + return msgpack_numpy.unpackb(response) + + @override + def reset(self) -> None: + pass diff --git a/capvector-pi05/scripts/__init__.py b/capvector-pi05/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-pi05/scripts/compute_norm_stats.py b/capvector-pi05/scripts/compute_norm_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..07ccb5e9ac4acb5c8c43ac64fbc2684edf8d7495 --- /dev/null +++ b/capvector-pi05/scripts/compute_norm_stats.py @@ -0,0 +1,117 @@ +"""Compute normalization statistics for a config. + +This script is used to compute the normalization statistics for a given config. It +will compute the mean and standard deviation of the data in the dataset and save it +to the config assets directory. 
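+
+Example invocation (the config name must match one defined in src/openpi/training/config.py,
+e.g. the pi05_libero config used by the LIBERO example):
+
+    uv run scripts/compute_norm_stats.py --config-name pi05_libero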
+""" + +import numpy as np +import tqdm +import tyro + +import openpi.models.model as _model +import openpi.shared.normalize as normalize +import openpi.training.config as _config +import openpi.training.data_loader as _data_loader +import openpi.transforms as transforms + + +class RemoveStrings(transforms.DataTransformFn): + def __call__(self, x: dict) -> dict: + return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)} + + +def create_torch_dataloader( + data_config: _config.DataConfig, + action_horizon: int, + batch_size: int, + model_config: _model.BaseModelConfig, + num_workers: int, + max_frames: int | None = None, +) -> tuple[_data_loader.Dataset, int]: + if data_config.repo_id is None: + raise ValueError("Data config must have a repo_id") + dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config) + dataset = _data_loader.TransformedDataset( + dataset, + [ + *data_config.repack_transforms.inputs, + *data_config.data_transforms.inputs, + # Remove strings since they are not supported by JAX and are not needed to compute norm stats. + RemoveStrings(), + ], + ) + if max_frames is not None and max_frames < len(dataset): + num_batches = max_frames // batch_size + shuffle = True + else: + num_batches = len(dataset) // batch_size + shuffle = False + data_loader = _data_loader.TorchDataLoader( + dataset, + local_batch_size=batch_size, + num_workers=num_workers, + shuffle=shuffle, + num_batches=num_batches, + ) + return data_loader, num_batches + + +def create_rlds_dataloader( + data_config: _config.DataConfig, + action_horizon: int, + batch_size: int, + max_frames: int | None = None, +) -> tuple[_data_loader.Dataset, int]: + dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False) + dataset = _data_loader.IterableTransformedDataset( + dataset, + [ + *data_config.repack_transforms.inputs, + *data_config.data_transforms.inputs, + # Remove strings since they are not supported by JAX and are not needed to compute norm stats. + RemoveStrings(), + ], + is_batched=True, + ) + if max_frames is not None and max_frames < len(dataset): + num_batches = max_frames // batch_size + else: + # NOTE: this length is currently hard-coded for DROID. 
+ num_batches = len(dataset) // batch_size + data_loader = _data_loader.RLDSDataLoader( + dataset, + num_batches=num_batches, + ) + return data_loader, num_batches + + +def main(config_name: str, max_frames: int | None = None): + config = _config.get_config(config_name) + data_config = config.data.create(config.assets_dirs, config.model) + + if data_config.rlds_data_dir is not None: + data_loader, num_batches = create_rlds_dataloader( + data_config, config.model.action_horizon, config.batch_size, max_frames + ) + else: + data_loader, num_batches = create_torch_dataloader( + data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames + ) + + keys = ["state", "actions"] + stats = {key: normalize.RunningStats() for key in keys} + + for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"): + for key in keys: + stats[key].update(np.asarray(batch[key])) + + norm_stats = {key: stats.get_statistics() for key, stats in stats.items()} + + output_path = config.assets_dirs / data_config.repo_id + print(f"Writing stats to: {output_path}") + normalize.save(output_path, norm_stats) + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/capvector-pi05/scripts/docker/compose.yml b/capvector-pi05/scripts/docker/compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..3655b85cf287df0f4e4e586bae97a6c607841516 --- /dev/null +++ b/capvector-pi05/scripts/docker/compose.yml @@ -0,0 +1,29 @@ +# Run with: +# docker compose -f scripts/docker/compose.yml up --build +services: + openpi_server: + image: openpi_server + build: + context: ../.. + dockerfile: scripts/docker/serve_policy.Dockerfile + init: true + tty: true + network_mode: host + # Populate configured openpi data home to /openpi_assets inside the container. + # Populate aws credential inside the container. + volumes: + - $PWD:/app + - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets + environment: + - SERVER_ARGS + - OPENPI_DATA_HOME=/openpi_assets + - IS_DOCKER=true + + # Comment out this block if not running on a machine with GPUs. + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/capvector-pi05/scripts/docker/install_docker_ubuntu22.sh b/capvector-pi05/scripts/docker/install_docker_ubuntu22.sh new file mode 100644 index 0000000000000000000000000000000000000000..cdda7fd608abde9aa99ab9c47049db6ae59a90db --- /dev/null +++ b/capvector-pi05/scripts/docker/install_docker_ubuntu22.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Add Docker's official GPG key: +sudo apt-get update +sudo apt-get install -y ca-certificates curl +sudo install -m 0755 -d /etc/apt/keyrings +sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc +sudo chmod a+r /etc/apt/keyrings/docker.asc + +# Add the repository to Apt sources: +echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \ + $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | + sudo tee /etc/apt/sources.list.d/docker.list >/dev/null +sudo apt-get update + +sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin + +# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc). 
+# See https://docs.docker.com/engine/install/linux-postinstall/
+username=$(whoami)
+sudo usermod -aG docker $username
+
+# Configure docker to start automatically on system boot.
+sudo systemctl enable docker.service
+sudo systemctl enable containerd.service
+
+# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
+if [ -f ~/.docker/config.json ]; then
+    sed -i 's/credsStore/credStore/g' ~/.docker/config.json
+fi
+
+echo ""
+echo "********************************************************************"
+echo "**** Restart to allow Docker permission changes to take effect. ****"
+echo "********************************************************************"
+echo ""
diff --git a/capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh b/capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1a1583309d936ad358551f7224bbce0d3bf5c9d1
--- /dev/null
+++ b/capvector-pi05/scripts/docker/install_nvidia_container_toolkit.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
+# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+
+curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
+    curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
+    sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
+    sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+
+# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
+sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
+sudo apt-get update
+sudo apt-get install -y nvidia-container-toolkit
+
+sudo nvidia-ctk runtime configure --runtime=docker
+sudo systemctl restart docker
diff --git a/capvector-pi05/scripts/docker/serve_policy.Dockerfile b/capvector-pi05/scripts/docker/serve_policy.Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..4060254f052a48a7346b700394537a422dfb88ee
--- /dev/null
+++ b/capvector-pi05/scripts/docker/serve_policy.Dockerfile
@@ -0,0 +1,38 @@
+# Dockerfile for serving a PI policy.
+# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
+
+# Build the container:
+# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
+
+FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+WORKDIR /app
+
+# Needed because LeRobot uses git-lfs.
+RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv + +# Install the project's dependencies using the lockfile and settings +RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=uv.lock,target=uv.lock \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \ + --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \ + GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev + +# Copy transformers_replace files while preserving directory structure +COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/ +RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace + +CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS" diff --git a/capvector-pi05/scripts/serve_policy.py b/capvector-pi05/scripts/serve_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..edabae3deb3d7cc1c79a7fbb9e5a6059a8d82c01 --- /dev/null +++ b/capvector-pi05/scripts/serve_policy.py @@ -0,0 +1,122 @@ +import dataclasses +import enum +import logging +import socket + +import tyro + +from openpi.policies import policy as _policy +from openpi.policies import policy_config as _policy_config +from openpi.serving import websocket_policy_server +from openpi.training import config as _config + + +class EnvMode(enum.Enum): + """Supported environments.""" + + ALOHA = "aloha" + ALOHA_SIM = "aloha_sim" + DROID = "droid" + LIBERO = "libero" + + +@dataclasses.dataclass +class Checkpoint: + """Load a policy from a trained checkpoint.""" + + # Training config name (e.g., "pi0_aloha_sim"). + config: str + # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000"). + dir: str + + +@dataclasses.dataclass +class Default: + """Use the default policy for the given environment.""" + + +@dataclasses.dataclass +class Args: + """Arguments for the serve_policy script.""" + + # Environment to serve the policy for. This is only used when serving default policies. + env: EnvMode = EnvMode.ALOHA_SIM + + # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default + # prompt. + default_prompt: str | None = None + + # Port to serve the policy on. + port: int = 8000 + # Record the policy's behavior for debugging. + record: bool = False + + # Specifies how to load the policy. If not provided, the default policy for the environment will be used. + policy: Checkpoint | Default = dataclasses.field(default_factory=Default) + + +# Default checkpoints that should be used for each environment. 
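The `policy: Checkpoint | Default` field above is what makes the `policy:checkpoint --policy.config ... --policy.dir ...` syntax from the README work: tyro renders a union of dataclasses as subcommands. A self-contained sketch of that behavior, with names local to the snippet (the dict of per-environment defaults follows right after this aside):

```python
# Sketch of tyro's union-as-subcommand behavior; serve_policy.py's real
# dataclasses carry more fields than shown here.
import dataclasses

import tyro


@dataclasses.dataclass
class Checkpoint:
    config: str  # training config name
    dir: str  # checkpoint directory


@dataclasses.dataclass
class Default:
    """Use the environment's default checkpoint."""


@dataclasses.dataclass
class Args:
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)


if __name__ == "__main__":
    # `python demo.py`                             -> Args(policy=Default())
    # `python demo.py policy:checkpoint --policy.config pi05_libero --policy.dir ./ckpt`
    #                                              -> Args(policy=Checkpoint(...))
    print(tyro.cli(Args))
```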
+DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = { + EnvMode.ALOHA: Checkpoint( + config="pi05_aloha", + dir="gs://openpi-assets/checkpoints/pi05_base", + ), + EnvMode.ALOHA_SIM: Checkpoint( + config="pi0_aloha_sim", + dir="gs://openpi-assets/checkpoints/pi0_aloha_sim", + ), + EnvMode.DROID: Checkpoint( + config="pi05_droid", + dir="gs://openpi-assets/checkpoints/pi05_droid", + ), + EnvMode.LIBERO: Checkpoint( + config="pi05_libero", + dir="gs://openpi-assets/checkpoints/pi05_libero", + ), +} + + +def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy: + """Create a default policy for the given environment.""" + if checkpoint := DEFAULT_CHECKPOINT.get(env): + return _policy_config.create_trained_policy( + _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt + ) + raise ValueError(f"Unsupported environment mode: {env}") + + +def create_policy(args: Args) -> _policy.Policy: + """Create a policy from the given arguments.""" + match args.policy: + case Checkpoint(): + return _policy_config.create_trained_policy( + _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt + ) + case Default(): + return create_default_policy(args.env, default_prompt=args.default_prompt) + + +def main(args: Args) -> None: + policy = create_policy(args) + policy_metadata = policy.metadata + + # Record the policy's behavior. + if args.record: + policy = _policy.PolicyRecorder(policy, "policy_records") + + hostname = socket.gethostname() + local_ip = socket.gethostbyname(hostname) + logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip) + + server = websocket_policy_server.WebsocketPolicyServer( + policy=policy, + host="0.0.0.0", + port=args.port, + metadata=policy_metadata, + ) + server.serve_forever() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, force=True) + main(tyro.cli(Args)) diff --git a/capvector-pi05/scripts/train.py b/capvector-pi05/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3d37bb20f2f8dfadf94e77e30fe22a4b747fd137 --- /dev/null +++ b/capvector-pi05/scripts/train.py @@ -0,0 +1,280 @@ +import dataclasses +import functools +import logging +import platform +from typing import Any + +import etils.epath as epath +import flax.nnx as nnx +from flax.training import common_utils +import flax.traverse_util as traverse_util +import jax +import jax.experimental +import jax.numpy as jnp +import numpy as np +import optax +import tqdm_loggable.auto as tqdm +import wandb + +import openpi.models.model as _model +import openpi.shared.array_typing as at +import openpi.shared.nnx_utils as nnx_utils +import openpi.training.checkpoints as _checkpoints +import openpi.training.config as _config +import openpi.training.data_loader as _data_loader +import openpi.training.optimizer as _optimizer +import openpi.training.sharding as sharding +import openpi.training.utils as training_utils +import openpi.training.weight_loaders as _weight_loaders + + +def init_logging(): + """Custom logging format for better readability.""" + level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} + + class CustomFormatter(logging.Formatter): + def format(self, record): + record.levelname = level_mapping.get(record.levelname, record.levelname) + return super().format(record) + + formatter = CustomFormatter( + fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", + 
datefmt="%H:%M:%S", + ) + + logger = logging.getLogger() + logger.setLevel(logging.INFO) + logger.handlers[0].setFormatter(formatter) + + +def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True): + if not enabled: + wandb.init(mode="disabled") + return + + ckpt_dir = config.checkpoint_dir + if not ckpt_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") + if resuming: + run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() + wandb.init(id=run_id, resume="must", project=config.project_name) + else: + wandb.init( + name=config.exp_name, + config=dataclasses.asdict(config), + project=config.project_name, + ) + (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) + + if log_code: + wandb.run.log_code(epath.Path(__file__).parent.parent) + + +def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params: + """Loads and validates the weights. Returns a loaded subset of the weights.""" + loaded_params = loader.load(params_shape) + at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True) + + # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned. + return traverse_util.unflatten_dict( + {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)} + ) + + +@at.typecheck +def init_train_state( + config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool +) -> tuple[training_utils.TrainState, Any]: + tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None) + + def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState: + rng, model_rng = jax.random.split(rng) + # initialize the model (and its parameters). + model = config.model.create(model_rng) + + # Merge the partial params into the model. + if partial_params is not None: + graphdef, state = nnx.split(model) + # This will produce an error if the partial params are not a subset of the state. + state.replace_by_pure_dict(partial_params) + model = nnx.merge(graphdef, state) + + params = nnx.state(model) + # Convert frozen params to bfloat16. + params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16))) + + return training_utils.TrainState( + step=0, + params=params, + model_def=nnx.graphdef(model), + tx=tx, + opt_state=tx.init(params.filter(config.trainable_filter)), + ema_decay=config.ema_decay, + ema_params=None if config.ema_decay is None else params, + ) + + train_state_shape = jax.eval_shape(init, init_rng) + state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True) + + if resume: + return train_state_shape, state_sharding + + partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict()) + replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + # Initialize the train state and mix in the partial params. + train_state = jax.jit( + init, + donate_argnums=(1,), # donate the partial params buffer. 
+ in_shardings=replicated_sharding, + out_shardings=state_sharding, + )(init_rng, partial_params) + + return train_state, state_sharding + + +@at.typecheck +def train_step( + config: _config.TrainConfig, + rng: at.KeyArrayLike, + state: training_utils.TrainState, + batch: tuple[_model.Observation, _model.Actions], +) -> tuple[training_utils.TrainState, dict[str, at.Array]]: + model = nnx.merge(state.model_def, state.params) + model.train() + + @at.typecheck + def loss_fn( + model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions + ): + chunked_loss = model.compute_loss(rng, observation, actions, train=True) + return jnp.mean(chunked_loss) + + train_rng = jax.random.fold_in(rng, state.step) + observation, actions = batch + + # Filter out frozen params. + diff_state = nnx.DiffState(0, config.trainable_filter) + loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions) + + params = state.params.filter(config.trainable_filter) + updates, new_opt_state = state.tx.update(grads, state.opt_state, params) + new_params = optax.apply_updates(params, updates) + + # Update the model in place and return the new full state. + nnx.update(model, new_params) + new_params = nnx.state(model) + + new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state) + if state.ema_decay is not None: + new_state = dataclasses.replace( + new_state, + ema_params=jax.tree.map( + lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params + ), + ) + + # Filter out params that aren't kernels. + kernel_params = nnx.state( + model, + nnx.All( + nnx.Param, + nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")), + lambda _, x: x.value.ndim > 1, + ), + ) + info = { + "loss": loss, + "grad_norm": optax.global_norm(grads), + "param_norm": optax.global_norm(kernel_params), + } + return new_state, info + + +def main(config: _config.TrainConfig): + init_logging() + logging.info(f"Running on: {platform.node()}") + + if config.batch_size % jax.device_count() != 0: + raise ValueError( + f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}." + ) + + jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser())) + + rng = jax.random.key(config.seed) + train_rng, init_rng = jax.random.split(rng) + + mesh = sharding.make_mesh(config.fsdp_devices) + data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS)) + replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir( + config.checkpoint_dir, + keep_period=config.keep_period, + overwrite=config.overwrite, + resume=config.resume, + ) + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + data_loader = _data_loader.create_data_loader( + config, + sharding=data_sharding, + shuffle=True, + ) + data_iter = iter(data_loader) + batch = next(data_iter) + logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}") + + # Log images from first batch to sanity check. 
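A quick aside on the EMA update in `train_step` above, before the image-logging code that follows: the rule is a standard exponential moving average, so with `ema_decay=0.99` a sudden change in a parameter is about 63% absorbed after 100 steps (1 - 0.99**100 ≈ 0.634). A tiny numeric check:

```python
import jax
import jax.numpy as jnp

# EMA rule from train_step: ema <- decay * ema + (1 - decay) * new.
decay = 0.99
ema = {"w": jnp.zeros(3)}
new = {"w": jnp.ones(3)}
for _ in range(100):
    ema = jax.tree.map(lambda old, cur: decay * old + (1 - decay) * cur, ema, new)
print(ema["w"])  # ~0.634 in each entry
```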
+ images_to_log = [ + wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1)) + for i in range(min(5, len(next(iter(batch[0].images.values()))))) + ] + wandb.log({"camera_views": images_to_log}, step=0) + + train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming) + jax.block_until_ready(train_state) + logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}") + + if resuming: + train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader) + + ptrain_step = jax.jit( + functools.partial(train_step, config), + in_shardings=(replicated_sharding, train_state_sharding, data_sharding), + out_shardings=(train_state_sharding, replicated_sharding), + donate_argnums=(1,), + ) + + start_step = int(train_state.step) + pbar = tqdm.tqdm( + range(start_step, config.num_train_steps), + initial=start_step, + total=config.num_train_steps, + dynamic_ncols=True, + ) + + infos = [] + for step in pbar: + with sharding.set_mesh(mesh): + train_state, info = ptrain_step(train_rng, train_state, batch) + infos.append(info) + if step % config.log_interval == 0: + stacked_infos = common_utils.stack_forest(infos) + reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos)) + info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items()) + pbar.write(f"Step {step}: {info_str}") + wandb.log(reduced_info, step=step) + infos = [] + batch = next(data_iter) + + if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1: + _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step) + + logging.info("Waiting for checkpoint manager to finish") + checkpoint_manager.wait_until_finished() + + +if __name__ == "__main__": + main(_config.cli()) diff --git a/capvector-pi05/scripts/train_align_pytorch.py b/capvector-pi05/scripts/train_align_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..ff1479ce7927f49e0b46e382cf7c57ef0ea86e72 --- /dev/null +++ b/capvector-pi05/scripts/train_align_pytorch.py @@ -0,0 +1,658 @@ +""" +PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. +This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs +entirely in PyTorch using the `PI0Pytorch` model and your existing config/data +pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. 
+
+Usage
+Single GPU:
+    python scripts/train_align_pytorch.py <config_name> --exp_name <exp_name> --save_interval <save_interval>
+    Example:
+    python scripts/train_align_pytorch.py debug --exp_name pytorch_ddp_test
+    python scripts/train_align_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
+Multi-GPU (single node):
+    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_align_pytorch.py <config_name> --exp_name <exp_name>
+    Example:
+    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_align_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
+    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_align_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
+Multi-Node Training:
+    torchrun \
+        --nnodes=<num_nodes> --nproc_per_node=<num_gpus> --node_rank=<node_rank> \
+        --master_addr=<master_addr> --master_port=<master_port> \
+        scripts/train_align_pytorch.py <config_name> --exp_name=<exp_name> --save_interval <save_interval>
+
+"""
+
+import dataclasses
+import gc
+import logging
+import os
+import platform
+import shutil
+import time
+
+import jax
+import numpy as np
+import safetensors.torch
+import torch
+import torch.distributed as dist
+import torch.nn.parallel
+import tqdm
+import wandb
+
+import openpi.models.pi0_config
+from openpi.models_pytorch import pi0_pytorch, pi0_align_pytorch, projectors
+import openpi.shared.normalize as _normalize
+import openpi.training.config as _config
+import openpi.training.data_loader as _data
+
+from vggt.models.vggt import VGGT
+
+
+def init_logging():
+    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
+
+    class CustomFormatter(logging.Formatter):
+        def format(self, record):
+            record.levelname = level_mapping.get(record.levelname, record.levelname)
+            return super().format(record)
+
+    formatter = CustomFormatter(
+        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
+        datefmt="%H:%M:%S",
+    )
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    if not logger.handlers:
+        ch = logging.StreamHandler()
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+    else:
+        logger.handlers[0].setFormatter(formatter)
+
+
+def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
+    """Initialize wandb logging."""
+    if not enabled:
+        wandb.init(mode="disabled")
+        return
+
+    ckpt_dir = config.checkpoint_dir
+    if not ckpt_dir.exists():
+        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
+
+    if resuming:
+        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
+        wandb.init(id=run_id, resume="must", project=config.project_name)
+    else:
+        wandb.init(
+            name=config.exp_name,
+            config=dataclasses.asdict(config),
+            project=config.project_name,
+        )
+        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
+
+
+def setup_ddp():
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    use_ddp = world_size > 1
+    if use_ddp and not torch.distributed.is_initialized():
+        backend = "nccl" if torch.cuda.is_available() else "gloo"
+        torch.distributed.init_process_group(backend=backend, init_method="env://")
+
+    # Set up debugging environment variables for DDP issues
+    if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
+        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
+
+    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
+    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.is_available():
+        torch.cuda.set_device(device)
+    return use_ddp, local_rank, device
+
+
+def cleanup_ddp():
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
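(The matching `destroy_process_group()` call of `cleanup_ddp` continues below.) For reference, `setup_ddp` above relies entirely on the environment variables that `torchrun` exports for each worker; a self-contained sketch of that contract, with the values a two-GPU single-node launch would produce:

```python
# Sketch of the torchrun contract that setup_ddp() depends on; torchrun exports
# WORLD_SIZE/RANK/LOCAL_RANK for every worker process it spawns.
import os

import torch

world_size = int(os.environ.get("WORLD_SIZE", "1"))  # 2 for --nproc_per_node=2
rank = int(os.environ.get("RANK", "0"))              # globally unique worker id
local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # GPU index on this node

if world_size > 1:
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
```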
torch.distributed.destroy_process_group() + + +def set_seed(seed: int, local_rank: int): + torch.manual_seed(seed + local_rank) + np.random.seed(seed + local_rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + local_rank) + + +def build_datasets(config: _config.TrainConfig): + # Use the unified data loader with PyTorch framework + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) + return data_loader, data_loader.data_config() + + +def get_model_state_dict(model): + """Get state dict from model, handling DDP wrapper.""" + return ( + model.module.state_dict() + if isinstance(model, torch.nn.parallel.DistributedDataParallel) + else model.state_dict() + ) + + +def get_model_parameters(model): + """Get parameters from model, handling DDP wrapper.""" + return ( + model.module.parameters() + if isinstance(model, torch.nn.parallel.DistributedDataParallel) + else model.parameters() + ) + + +def save_checkpoint(model, optimizer, global_step, config, is_main, data_config): + """Save a checkpoint with model state, optimizer state, and metadata.""" + if not is_main: + return + + # Only save if it's time to save or if it's the final step + if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + # Create temporary directory for atomic checkpoint saving + final_ckpt_dir = config.checkpoint_dir / f"{global_step}" + tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}" + + # Remove any existing temp directory and create new one + if tmp_ckpt_dir.exists(): + shutil.rmtree(tmp_ckpt_dir) + tmp_ckpt_dir.mkdir(parents=True, exist_ok=True) + + # Save model state using safetensors (handle shared tensors) + model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors") + + # Save optimizer state using PyTorch format + torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt") + + # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) + metadata = { + "global_step": global_step, + "config": dataclasses.asdict(config), + "timestamp": time.time(), + } + torch.save(metadata, tmp_ckpt_dir / "metadata.pt") + + # save norm stats + norm_stats = data_config.norm_stats + if norm_stats is not None and data_config.asset_id is not None: + _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats) + + # Atomically move temp directory to final location + if final_ckpt_dir.exists(): + shutil.rmtree(final_ckpt_dir) + tmp_ckpt_dir.rename(final_ckpt_dir) + + logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}") + + # Log checkpoint to wandb + if config.wandb_enabled: + wandb.log({"checkpoint_step": global_step}, step=global_step) + + +def load_checkpoint(model, optimizer, checkpoint_dir, device): + """Load the latest checkpoint and return the global step.""" + checkpoint_steps = [ + int(d.name) + for d in checkpoint_dir.iterdir() + if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") + ] + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = checkpoint_dir / f"{latest_step}" + + # Clear memory before loading checkpoints + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "before_loading_checkpoint") + + try: + # Load model state with error 
handling + logging.info("Loading model state...") + safetensors_path = ckpt_dir / "model.safetensors" + + if safetensors_path.exists(): + model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device)) + logging.info("Loaded model state from safetensors format") + else: + raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}") + + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_model") + + # Load optimizer state with error handling + logging.info("Loading optimizer state...") + optimizer_path = ckpt_dir / "optimizer.pt" + + if optimizer_path.exists(): + optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False) + logging.info("Loaded optimizer state from pt format") + else: + raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}") + + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_optimizer") + + # Load metadata + logging.info("Loading metadata...") + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + del metadata + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_metadata") + + logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") + return global_step + + except RuntimeError as e: + if "out of memory" in str(e): + # Clear memory and provide detailed error message + torch.cuda.empty_cache() + gc.collect() + logging.error(f"Out of memory error while loading checkpoint: {e!s}") + log_memory_usage(device, latest_step, "after_oom_error") + raise RuntimeError( + "Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" + ) from e + raise + + +def get_latest_checkpoint_step(checkpoint_dir): + """Get the latest checkpoint step number from a checkpoint directory.""" + checkpoint_steps = [ + int(d.name) + for d in checkpoint_dir.iterdir() + if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") + ] + return max(checkpoint_steps) if checkpoint_steps else None + + +def log_memory_usage(device, step, phase="unknown"): + """Log detailed memory usage information.""" + if not torch.cuda.is_available(): + return + + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) + memory_free = memory_free / 1e9 + + # Get more detailed memory info + memory_stats = torch.cuda.memory_stats(device) + max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9 + max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9 + + # Get DDP info if available + ddp_info = "" + if dist.is_initialized(): + ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" + + logging.info( + f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}" + ) + + +def train_loop(config: _config.TrainConfig): + use_ddp, local_rank, device = setup_ddp() + is_main = (not use_ddp) or (dist.get_rank() == 0) + set_seed(config.seed, local_rank) + + # Initialize checkpoint directory and wandb + resuming = False + if config.resume: + # Find checkpoint directory based on experiment name + exp_checkpoint_dir = config.checkpoint_dir + if exp_checkpoint_dir.exists(): + # Use validation to find the latest working checkpoint + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) + if latest_step is not None: + resuming = True + logging.info( + f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" + ) + else: + raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") + else: + raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") + elif config.overwrite and config.checkpoint_dir.exists(): + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + + # Create checkpoint directory with experiment name + if not resuming: + # For new runs, create experiment-specific checkpoint directory + exp_checkpoint_dir = config.checkpoint_dir + exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") + else: + # For resume, checkpoint_dir is already set to the experiment directory + logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") + + # Initialize wandb (only on main process) + if is_main: + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + # Build data loader using the unified data loader + # Calculate effective batch size per GPU for DDP + # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size + world_size = torch.distributed.get_world_size() if use_ddp else 1 + effective_batch_size = config.batch_size // world_size + logging.info( + f"Using batch size per GPU: 
{effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" + ) + + # Pass the original batch size to data loader - it will handle DDP splitting internally + loader, data_config = build_datasets(config) + + # Log sample images to wandb on first batch + if is_main and config.wandb_enabled and not resuming: + # Create a separate data loader for sample batch to avoid consuming the main loader + sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) + sample_batch = next(iter(sample_data_loader)) + # Convert observation and actions to torch tensors + observation, actions = sample_batch + sample_batch = observation.to_dict() + sample_batch["actions"] = actions + + # Create sample images for wandb + images_to_log = [] + # Get batch size from the first image tensor + batch_size = next(iter(sample_batch["image"].values())).shape[0] + for i in range(min(5, batch_size)): + # Concatenate all camera views horizontally for this batch item + # Convert from NCHW to NHWC format for wandb + img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) + img_concatenated = img_concatenated.cpu().numpy() + images_to_log.append(wandb.Image(img_concatenated)) + + wandb.log({"camera_views": images_to_log}, step=0) + + # Clear sample batch from memory aggressively + del sample_batch, observation, actions, images_to_log, img_concatenated + del sample_data_loader # Also delete the sample data loader + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logging.info("Cleared sample batch and data loader from memory") + + # Build model + if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): + # Convert dataclass to Pi0Config if needed + model_cfg = openpi.models.pi0_config.Pi0Config( + dtype=config.pytorch_training_precision, + action_dim=config.model.action_dim, + action_horizon=config.model.action_horizon, + max_token_len=config.model.max_token_len, + paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), + action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), + pi05=getattr(config.model, "pi05", False), + ) + else: + model_cfg = config.model + # Update dtype to match pytorch_training_precision + object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) + + model = openpi.models_pytorch.pi0_align_pytorch.PI0Pytorch(model_cfg, config).to(device) + vggt_model = VGGT( + enable_camera=False, + enable_point=False, + enable_depth=False, + enable_track=False, + feature_only=True, + ).to(device) + align_projector = projectors.AlignProjector( + model.LLM_width, + config.vggt_dim, + config.use_vlm_norm).to(device) + + if hasattr(model, "gradient_checkpointing_enable"): + enable_gradient_checkpointing = True + model.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for memory optimization") + else: + enable_gradient_checkpointing = False + logging.info("Gradient checkpointing is not supported for this model") + + # Log initial memory usage after model creation + if is_main and torch.cuda.is_available(): + log_memory_usage(device, 0, "after_model_creation") + + # Enable memory optimizations for large-scale training + if world_size >= 8: + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # Set memory allocation configuration + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" 
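One caveat about the allocator setting just above: PyTorch reads `PYTORCH_CUDA_ALLOC_CONF` when the CUDA caching allocator first initializes, so assigning it here, after the model has already been moved to the GPU, may be a no-op. Exporting it before the process starts is the reliable pattern; a sketch:

```python
# Reliable pattern (sketch): configure the allocator in the environment before
# any CUDA work, e.g. in the launcher:
#   PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True \
#       torchrun --standalone --nproc_per_node=8 scripts/train_align_pytorch.py ...
# or at the very top of the entrypoint, before the first tensor touches the GPU:
import os

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")

import torch  # noqa: E402  # first CUDA allocation must happen after the env var is set
```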
+ logging.info("Enabled memory optimizations for 8+ GPU training") + + if use_ddp: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device.index] if device.type == "cuda" else None, + find_unused_parameters=True, # Disable for memory efficiency + gradient_as_bucket_view=True, # Enable for memory efficiency + static_graph=world_size >= 8, # Enable for 8+ GPUs + ) + align_projector = torch.nn.parallel.DistributedDataParallel( + align_projector, + device_ids=[device.index] if device.type == "cuda" else None, + find_unused_parameters=True, # Disable for memory efficiency + gradient_as_bucket_view=True, # Enable for memory efficiency + static_graph=world_size >= 8, # Enable for 8+ GPUs + ) + + # Load weights from weight_loader if specified (for fine-tuning) + if config.pytorch_weight_path is not None: + logging.info(f"Loading weights from: {config.pytorch_weight_path}") + model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") + safetensors.torch.load_model( + (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), + model_path, + strict=False, + ) + logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") + if config.vggt_weight_path is not None: + vggt_path = os.path.join(config.vggt_weight_path, "model.pt") + if not os.path.exists(vggt_path): + raise FileNotFoundError(f"VGGT weight file not found at {vggt_path}") + vggt_model.load_state_dict(torch.load(vggt_path), strict=False) + logging.info(f"Loaded VGGT weights from {config.vggt_weight_path}") + + # Optimizer + learning rate schedule from config + warmup_steps = config.lr_schedule.warmup_steps + peak_lr = config.lr_schedule.peak_lr + decay_steps = config.lr_schedule.decay_steps + end_lr = config.lr_schedule.decay_lr + + # Create optimizer with config parameters + optim = torch.optim.AdamW( + list(model.parameters()) + list(align_projector.parameters()), + lr=peak_lr, + betas=(config.optimizer.b1, config.optimizer.b2), + eps=config.optimizer.eps, + weight_decay=config.optimizer.weight_decay, + ) + + # Load checkpoint if resuming + global_step = 0 + if resuming: + global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) + logging.info(f"Resumed training from step {global_step}") + + def lr_schedule(step: int): + if step < warmup_steps: + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + model.train() + align_projector.train() + vggt_model.eval() + start_time = time.time() + infos = [] # Collect stats over log interval + if is_main: + logging.info( + f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}" + ) + logging.info( + f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" + ) + logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") + logging.info( + f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" + ) + logging.info( + f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" + ) + logging.info("EMA is not 
supported for PyTorch training") + logging.info(f"Training precision: {model_cfg.dtype}") + + # Training loop - iterate until we reach num_train_steps + pbar = ( + tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) + if is_main + else None + ) + + while global_step < config.num_train_steps: + # Set epoch for distributed training + if use_ddp and hasattr(loader, "set_epoch"): + loader.set_epoch(global_step // len(loader)) + + for observation, actions in loader: + # Check if we've reached the target number of steps + if global_step >= config.num_train_steps: + break + + # The unified data loader returns (observation, actions) tuple + observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901 + actions = actions.to(torch.float32) # noqa: PLW2901 + actions = actions.to(device) # noqa: PLW2901 + + # Update LR + for pg in optim.param_groups: + pg["lr"] = lr_schedule(global_step) + + # Forward pass + action_losses, align_loss = model(observation, actions, vggt=vggt_model, align_proj=align_projector) + loss = action_losses + config.align_loss_coeff * align_loss + + # Backward pass + loss.backward() + + # Log memory usage after backward pass + if global_step < 5 and is_main and torch.cuda.is_available(): + log_memory_usage(device, global_step, "after_backward") + + # Gradient clipping + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + optim.step() + optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None + + # Collect stats + if is_main: + infos.append( + { + "action_loss": action_losses.item(), + "align_loss": align_loss.item(), + "learning_rate": optim.param_groups[0]["lr"], + "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, + } + ) + + if is_main and (global_step % config.log_interval == 0): + elapsed = time.time() - start_time + + # Average stats over log interval + avg_loss = sum(info["action_loss"] for info in infos) / len(infos) + avg_align_loss = sum(info["align_loss"] for info in infos) / len(infos) + avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) + + avg_grad_norm = None + if any("grad_norm" in info for info in infos): + vals = [ + info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None + ] + if len(vals) > 0: + avg_grad_norm = sum(vals) / len(vals) + logging.info( + f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" + if avg_grad_norm is not None + else f"step={global_step} action_loss={avg_loss:.4f} align_loss={avg_align_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s" + ) + + # Log to wandb + if config.wandb_enabled and len(infos) > 0: + log_payload = { + "action_loss": avg_loss, + "align_loss": avg_align_loss, + "learning_rate": avg_lr, + "step": global_step, + "time_per_step": elapsed / config.log_interval, + } + if avg_grad_norm is not None: + log_payload["grad_norm"] = avg_grad_norm + wandb.log(log_payload, step=global_step) + + start_time = time.time() + infos = [] # Reset stats collection + + global_step += 1 + # Save checkpoint using the new mechanism + save_checkpoint(model, optim, global_step, config, is_main, data_config) + + # Update progress bar + if pbar is not None: + pbar.update(1) + pbar.set_postfix( + {"loss": 
f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step} + ) + + # Close progress bar + if pbar is not None: + pbar.close() + + # Finish wandb run + if is_main and config.wandb_enabled: + wandb.finish() + + cleanup_ddp() + + +def main(): + init_logging() + config = _config.cli() + train_loop(config) + + +if __name__ == "__main__": + main() diff --git a/capvector-pi05/scripts/train_pytorch.py b/capvector-pi05/scripts/train_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..a03c206466e99c506a0debc0ec51b5b3302d0249 --- /dev/null +++ b/capvector-pi05/scripts/train_pytorch.py @@ -0,0 +1,632 @@ +""" +PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. +This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs +entirely in PyTorch using the `PI0Pytorch` model and your existing config/data +pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. + +Usage +Single GPU: + python scripts/train_pytorch.py --exp_name --save_interval + Example: + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test + python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint +Multi-GPU (single node): + torchrun --standalone --nnodes=1 --nproc_per_node= scripts/train_pytorch.py --exp_name + Example: + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test + torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume +Multi-Node Training: + torchrun \ + --nnodes= --nproc_per_node= --node_rank= \ + --master_addr= --master_port= \ + scripts/train_pytorch.py --exp_name= --save_interval + +""" + +import dataclasses +import gc +import logging +import os +import platform +import shutil +import time + +import jax +import numpy as np +import safetensors.torch +import torch +import torch.distributed as dist +import torch.nn.parallel +import tqdm +import wandb + +import openpi.models.pi0_config +import openpi.models_pytorch.pi0_pytorch +import openpi.shared.normalize as _normalize +import openpi.training.config as _config +import openpi.training.data_loader as _data + + +def init_logging(): + level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"} + + class CustomFormatter(logging.Formatter): + def format(self, record): + record.levelname = level_mapping.get(record.levelname, record.levelname) + return super().format(record) + + formatter = CustomFormatter( + fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)", + datefmt="%H:%M:%S", + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + if not logger.handlers: + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + else: + logger.handlers[0].setFormatter(formatter) + + +def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True): + """Initialize wandb logging.""" + if not enabled: + wandb.init(mode="disabled") + return + + ckpt_dir = config.checkpoint_dir + if not ckpt_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") + + if resuming: + run_id = (ckpt_dir / "wandb_id.txt").read_text().strip() + wandb.init(id=run_id, resume="must", project=config.project_name) + else: + wandb.init( + name=config.exp_name, + config=dataclasses.asdict(config), + 
project=config.project_name, + ) + (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id) + + +def setup_ddp(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + use_ddp = world_size > 1 + if use_ddp and not torch.distributed.is_initialized(): + backend = "nccl" if torch.cuda.is_available() else "gloo" + torch.distributed.init_process_group(backend=backend, init_method="env://") + + # Set up debugging environment variables for DDP issues + if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" + + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(device) + return use_ddp, local_rank, device + + +def cleanup_ddp(): + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +def set_seed(seed: int, local_rank: int): + torch.manual_seed(seed + local_rank) + np.random.seed(seed + local_rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + local_rank) + + +def build_datasets(config: _config.TrainConfig): + # Use the unified data loader with PyTorch framework + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) + return data_loader, data_loader.data_config() + + +def get_model_state_dict(model): + """Get state dict from model, handling DDP wrapper.""" + return ( + model.module.state_dict() + if isinstance(model, torch.nn.parallel.DistributedDataParallel) + else model.state_dict() + ) + + +def get_model_parameters(model): + """Get parameters from model, handling DDP wrapper.""" + return ( + model.module.parameters() + if isinstance(model, torch.nn.parallel.DistributedDataParallel) + else model.parameters() + ) + + +def save_checkpoint(model, optimizer, global_step, config, is_main, data_config): + """Save a checkpoint with model state, optimizer state, and metadata.""" + if not is_main: + return + + # Only save if it's time to save or if it's the final step + if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + # Create temporary directory for atomic checkpoint saving + final_ckpt_dir = config.checkpoint_dir / f"{global_step}" + tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}" + + # Remove any existing temp directory and create new one + if tmp_ckpt_dir.exists(): + shutil.rmtree(tmp_ckpt_dir) + tmp_ckpt_dir.mkdir(parents=True, exist_ok=True) + + # Save model state using safetensors (handle shared tensors) + model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors") + + # Save optimizer state using PyTorch format + torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt") + + # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) + metadata = { + "global_step": global_step, + "config": dataclasses.asdict(config), + "timestamp": time.time(), + } + torch.save(metadata, tmp_ckpt_dir / "metadata.pt") + + # save norm stats + norm_stats = data_config.norm_stats + if norm_stats is not None and data_config.asset_id is not None: + _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats) + + # Atomically move temp directory to final location + if final_ckpt_dir.exists(): + shutil.rmtree(final_ckpt_dir) + 
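(The `tmp_ckpt_dir.rename(final_ckpt_dir)` call that completes the swap continues below.) The write-to-temp-then-rename pattern used throughout `save_checkpoint` is what keeps an interrupted save from corrupting a checkpoint; in sketch form:

```python
import pathlib
import shutil


# Sketch of the atomic-save pattern: a crash before the final rename leaves only
# a tmp_* directory, which the resume logic already filters out by name.
def atomic_save(ckpt_root: pathlib.Path, step: int, write_fn) -> pathlib.Path:
    tmp = ckpt_root / f"tmp_{step}"
    final = ckpt_root / f"{step}"
    if tmp.exists():
        shutil.rmtree(tmp)
    tmp.mkdir(parents=True)
    write_fn(tmp)  # write model/optimizer/metadata into the temp directory
    if final.exists():
        shutil.rmtree(final)
    tmp.rename(final)  # one directory rename publishes the checkpoint
    return final
```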
tmp_ckpt_dir.rename(final_ckpt_dir) + + logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}") + + # Log checkpoint to wandb + if config.wandb_enabled: + wandb.log({"checkpoint_step": global_step}, step=global_step) + + +def load_checkpoint(model, optimizer, checkpoint_dir, device): + """Load the latest checkpoint and return the global step.""" + checkpoint_steps = [ + int(d.name) + for d in checkpoint_dir.iterdir() + if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") + ] + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = checkpoint_dir / f"{latest_step}" + + # Clear memory before loading checkpoints + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "before_loading_checkpoint") + + try: + # Load model state with error handling + logging.info("Loading model state...") + safetensors_path = ckpt_dir / "model.safetensors" + + if safetensors_path.exists(): + model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device)) + logging.info("Loaded model state from safetensors format") + else: + raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}") + + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_model") + + # Load optimizer state with error handling + logging.info("Loading optimizer state...") + optimizer_path = ckpt_dir / "optimizer.pt" + + if optimizer_path.exists(): + optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False) + logging.info("Loaded optimizer state from pt format") + else: + raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}") + + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_optimizer") + + # Load metadata + logging.info("Loading metadata...") + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + del metadata + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_metadata") + + logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") + return global_step + + except RuntimeError as e: + if "out of memory" in str(e): + # Clear memory and provide detailed error message + torch.cuda.empty_cache() + gc.collect() + logging.error(f"Out of memory error while loading checkpoint: {e!s}") + log_memory_usage(device, latest_step, "after_oom_error") + raise RuntimeError( + "Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" + ) from e + raise + + +def get_latest_checkpoint_step(checkpoint_dir): + """Get the latest checkpoint step number from a checkpoint directory.""" + checkpoint_steps = [ + int(d.name) + for d in checkpoint_dir.iterdir() + if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") + ] + return max(checkpoint_steps) if checkpoint_steps else None + + +def log_memory_usage(device, step, phase="unknown"): + """Log detailed memory usage information.""" + if not torch.cuda.is_available(): + return + + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) + memory_free = memory_free / 1e9 + + # Get more detailed memory info + memory_stats = torch.cuda.memory_stats(device) + max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9 + max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9 + + # Get DDP info if available + ddp_info = "" + if dist.is_initialized(): + ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" + + logging.info( + f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}" + ) + + +def train_loop(config: _config.TrainConfig): + use_ddp, local_rank, device = setup_ddp() + is_main = (not use_ddp) or (dist.get_rank() == 0) + set_seed(config.seed, local_rank) + + # Initialize checkpoint directory and wandb + resuming = False + if config.resume: + # Find checkpoint directory based on experiment name + exp_checkpoint_dir = config.checkpoint_dir + if exp_checkpoint_dir.exists(): + # Use validation to find the latest working checkpoint + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) + if latest_step is not None: + resuming = True + logging.info( + f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" + ) + else: + raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") + else: + raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") + elif config.overwrite and config.checkpoint_dir.exists(): + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + + # Create checkpoint directory with experiment name + if not resuming: + # For new runs, create experiment-specific checkpoint directory + exp_checkpoint_dir = config.checkpoint_dir + exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") + else: + # For resume, checkpoint_dir is already set to the experiment directory + logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") + + # Initialize wandb (only on main process) + if is_main: + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + # Build data loader using the unified data loader + # Calculate effective batch size per GPU for DDP + # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size + world_size = torch.distributed.get_world_size() if use_ddp else 1 + effective_batch_size = config.batch_size // world_size + logging.info( + f"Using batch size per GPU: 
{effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" + ) + + # Pass the original batch size to data loader - it will handle DDP splitting internally + loader, data_config = build_datasets(config) + + # Log sample images to wandb on first batch + if is_main and config.wandb_enabled and not resuming: + # Create a separate data loader for sample batch to avoid consuming the main loader + sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) + sample_batch = next(iter(sample_data_loader)) + # Convert observation and actions to torch tensors + observation, actions = sample_batch + sample_batch = observation.to_dict() + sample_batch["actions"] = actions + + # Create sample images for wandb + images_to_log = [] + # Get batch size from the first image tensor + batch_size = next(iter(sample_batch["image"].values())).shape[0] + for i in range(min(5, batch_size)): + # Concatenate all camera views horizontally for this batch item + # Convert from NCHW to NHWC format for wandb + img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) + img_concatenated = img_concatenated.cpu().numpy() + images_to_log.append(wandb.Image(img_concatenated)) + + wandb.log({"camera_views": images_to_log}, step=0) + + # Clear sample batch from memory aggressively + del sample_batch, observation, actions, images_to_log, img_concatenated + del sample_data_loader # Also delete the sample data loader + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logging.info("Cleared sample batch and data loader from memory") + + # Build model + if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): + # Convert dataclass to Pi0Config if needed + model_cfg = openpi.models.pi0_config.Pi0Config( + dtype=config.pytorch_training_precision, + action_dim=config.model.action_dim, + action_horizon=config.model.action_horizon, + max_token_len=config.model.max_token_len, + paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), + action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), + pi05=getattr(config.model, "pi05", False), + ) + else: + model_cfg = config.model + # Update dtype to match pytorch_training_precision + object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) + + model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) + + if hasattr(model, "gradient_checkpointing_enable"): + enable_gradient_checkpointing = True + model.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for memory optimization") + else: + enable_gradient_checkpointing = False + logging.info("Gradient checkpointing is not supported for this model") + + # Log initial memory usage after model creation + if is_main and torch.cuda.is_available(): + log_memory_usage(device, 0, "after_model_creation") + + # Enable memory optimizations for large-scale training + if world_size >= 8: + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # Set memory allocation configuration + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" + logging.info("Enabled memory optimizations for 8+ GPU training") + + if use_ddp: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device.index] if device.type == "cuda" else None, + find_unused_parameters=True, # Disable for memory efficiency + 
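+            # NOTE: find_unused_parameters=True makes DDP re-scan the graph for
+            # parameters that received no gradient on every step; it is only needed
+            # when some parameters may go unused, and can be set to False for a
+            # fully-used graph to save time and memory.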
gradient_as_bucket_view=True, # Enable for memory efficiency + static_graph=world_size >= 8, # Enable for 8+ GPUs + ) + + # Load weights from weight_loader if specified (for fine-tuning) + if config.pytorch_weight_path is not None: + logging.info(f"Loading weights from: {config.pytorch_weight_path}") + + model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") + safetensors.torch.load_model( + (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path + ) + logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") + + # Optimizer + learning rate schedule from config + warmup_steps = config.lr_schedule.warmup_steps + peak_lr = config.lr_schedule.peak_lr + decay_steps = config.lr_schedule.decay_steps + end_lr = config.lr_schedule.decay_lr + + # Create optimizer with config parameters + optim = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(config.optimizer.b1, config.optimizer.b2), + eps=config.optimizer.eps, + weight_decay=config.optimizer.weight_decay, + ) + + # Load checkpoint if resuming + global_step = 0 + if resuming: + global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) + logging.info(f"Resumed training from step {global_step}") + + def lr_schedule(step: int): + if step < warmup_steps: + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + model.train() + start_time = time.time() + infos = [] # Collect stats over log interval + if is_main: + logging.info( + f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}" + ) + logging.info( + f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" + ) + logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") + logging.info( + f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" + ) + logging.info( + f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" + ) + logging.info("EMA is not supported for PyTorch training") + logging.info(f"Training precision: {model_cfg.dtype}") + + # Training loop - iterate until we reach num_train_steps + pbar = ( + tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) + if is_main + else None + ) + + while global_step < config.num_train_steps: + # Set epoch for distributed training + if use_ddp and hasattr(loader, "set_epoch"): + loader.set_epoch(global_step // len(loader)) + + for observation, actions in loader: + # Check if we've reached the target number of steps + if global_step >= config.num_train_steps: + break + + # The unified data loader returns (observation, actions) tuple + observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901 + actions = actions.to(torch.float32) # noqa: PLW2901 + actions = actions.to(device) # noqa: PLW2901 + + # Update LR + for pg in optim.param_groups: + pg["lr"] = lr_schedule(global_step) + + # Forward pass + losses = model(observation, actions) + # Ensure losses is a tensor and handle different 
return types + if isinstance(losses, list | tuple): + losses = torch.stack(losses) + elif not isinstance(losses, torch.Tensor): + losses = torch.tensor(losses, device=device, dtype=torch.float32) + + loss = losses.mean() + + # Backward pass + loss.backward() + + # Log memory usage after backward pass + if global_step < 5 and is_main and torch.cuda.is_available(): + log_memory_usage(device, global_step, "after_backward") + + # Gradient clipping + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + optim.step() + optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None + + # Collect stats + if is_main: + infos.append( + { + "loss": loss.item(), + "learning_rate": optim.param_groups[0]["lr"], + "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, + } + ) + + if is_main and (global_step % config.log_interval == 0): + elapsed = time.time() - start_time + + # Average stats over log interval + avg_loss = sum(info["loss"] for info in infos) / len(infos) + avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) + + avg_grad_norm = None + if any("grad_norm" in info for info in infos): + vals = [ + info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None + ] + if len(vals) > 0: + avg_grad_norm = sum(vals) / len(vals) + logging.info( + f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" + if avg_grad_norm is not None + else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s" + ) + + # Log to wandb + if config.wandb_enabled and len(infos) > 0: + log_payload = { + "loss": avg_loss, + "learning_rate": avg_lr, + "step": global_step, + "time_per_step": elapsed / config.log_interval, + } + if avg_grad_norm is not None: + log_payload["grad_norm"] = avg_grad_norm + wandb.log(log_payload, step=global_step) + + start_time = time.time() + infos = [] # Reset stats collection + + global_step += 1 + # Save checkpoint using the new mechanism + save_checkpoint(model, optim, global_step, config, is_main, data_config) + + # Update progress bar + if pbar is not None: + pbar.update(1) + pbar.set_postfix( + {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step} + ) + + # Close progress bar + if pbar is not None: + pbar.close() + + # Finish wandb run + if is_main and config.wandb_enabled: + wandb.finish() + + cleanup_ddp() + + +def main(): + init_logging() + config = _config.cli() + train_loop(config) + + +if __name__ == "__main__": + main() diff --git a/capvector-pi05/scripts/train_regular_loss_pytorch.py b/capvector-pi05/scripts/train_regular_loss_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..2a688cfc72b950f46073b24177fbc4b6b13246f6 --- /dev/null +++ b/capvector-pi05/scripts/train_regular_loss_pytorch.py @@ -0,0 +1,754 @@ +""" +PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. +This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs +entirely in PyTorch using the `PI0Pytorch` model and your existing config/data +pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. 
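+The main addition over the plain PyTorch trainer is an optional delta-based
+regularization term: when a vector file is configured, the loss gains
+regularization_coeff * sum_p |<theta_p - theta_p^(0), v_p>|, where theta_p^(0) are the
+trainable weights captured at startup and v_p is the loaded reference direction for
+parameter p (see prepare_regularization_context and compute_regularization_loss below).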
+
+Usage
+Single GPU:
+    python scripts/train_regular_loss_pytorch.py <config_name> --exp_name <exp_name> --save_interval <interval>
+    Example:
+        python scripts/train_regular_loss_pytorch.py debug --exp_name pytorch_ddp_test
+        python scripts/train_regular_loss_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
+Multi-GPU (single node):
+    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_regular_loss_pytorch.py <config_name> --exp_name <exp_name>
+    Example:
+        torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_regular_loss_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
+        torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_regular_loss_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
+Multi-Node Training:
+    torchrun \
+        --nnodes=<num_nodes> --nproc_per_node=<num_gpus> --node_rank=<node_rank> \
+        --master_addr=<master_addr> --master_port=<master_port> \
+        scripts/train_regular_loss_pytorch.py <config_name> --exp_name=<exp_name> --save_interval <interval>
+
+"""
+
+import dataclasses
+import gc
+import logging
+import os
+import platform
+from pathlib import Path
+import shutil
+import time
+
+import jax
+import numpy as np
+import safetensors.torch
+import torch
+import torch.distributed as dist
+import torch.nn.parallel
+import tqdm
+import wandb
+
+import openpi.models.pi0_config
+import openpi.models_pytorch.pi0_pytorch
+import openpi.shared.normalize as _normalize
+import openpi.training.config as _config
+import openpi.training.data_loader as _data
+
+
+def init_logging():
+    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
+
+    class CustomFormatter(logging.Formatter):
+        def format(self, record):
+            record.levelname = level_mapping.get(record.levelname, record.levelname)
+            return super().format(record)
+
+    formatter = CustomFormatter(
+        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
+        datefmt="%H:%M:%S",
+    )
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    if not logger.handlers:
+        ch = logging.StreamHandler()
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+    else:
+        logger.handlers[0].setFormatter(formatter)
+
+
+def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
+    """Initialize wandb logging."""
+    if not enabled:
+        wandb.init(mode="disabled")
+        return
+
+    ckpt_dir = config.checkpoint_dir
+    if not ckpt_dir.exists():
+        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
+
+    if resuming:
+        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
+        wandb.init(id=run_id, resume="must", project=config.project_name)
+    else:
+        wandb.init(
+            name=config.name,
+            config=dataclasses.asdict(config),
+            project=config.project_name,
+            id="-".join([config.name, config.exp_name]),
+        )
+        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
+
+
+def setup_ddp():
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    use_ddp = world_size > 1
+    if use_ddp and not torch.distributed.is_initialized():
+        backend = "nccl" if torch.cuda.is_available() else "gloo"
+        torch.distributed.init_process_group(backend=backend, init_method="env://")
+
+        # Set up debugging environment variables for DDP issues
+        if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
+            os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
+
+    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
+    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.is_available():
+        torch.cuda.set_device(device)
+    return use_ddp, local_rank, device
+
+
+def cleanup_ddp():
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
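+        # The barrier keeps fast ranks from tearing down the process group while
+        # slower ranks still have collectives in flight.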
torch.distributed.destroy_process_group() + + +def set_seed(seed: int, local_rank: int): + torch.manual_seed(seed + local_rank) + np.random.seed(seed + local_rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + local_rank) + + +def build_datasets(config: _config.TrainConfig): + # Use the unified data loader with PyTorch framework + data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True) + return data_loader, data_loader.data_config() + + +def get_model_state_dict(model): + """Get state dict from model, handling DDP wrapper.""" + return ( + model.module.state_dict() + if isinstance(model, torch.nn.parallel.DistributedDataParallel) + else model.state_dict() + ) + + +def get_model_parameters(model): + """Get parameters from model, handling DDP wrapper.""" + return ( + model.module.parameters() + if isinstance(model, torch.nn.parallel.DistributedDataParallel) + else model.parameters() + ) + + +def load_regular_vector_dict(path: str | Path) -> dict[str, torch.Tensor]: + """Load the regularization vectors, which are used for delta-based regularization.""" + tensor_path = Path(path) + suffix = tensor_path.suffix.lower() + + if suffix in {".pt", ".pth"}: + tensors = torch.load(tensor_path, map_location="cpu", weights_only=False, mmap=True) + elif suffix == ".safetensors": + tensors = safetensors.torch.load_file(str(tensor_path), device="cpu") + else: + raise ValueError(f"Unsupported tensor file format: {tensor_path}") + + return tensors["state_dict"] + + +def prepare_regularization_context( + model, + config: _config.TrainConfig, +) -> dict | None: + """Load regularization tensors and build the runtime context for delta-based regularization.""" + + # Don't use regularization optionally + if not config.regularization_vector_path or config.regularization_coeff == 0.0: + return None + + # Get the regularization vectors as reference directions + if config.resume: + raise ValueError( + "Delta-based regularization with --resume is not supported in this PyTorch trainer. " + "This run now keeps the anchor only in memory at startup." 
+        )
+ ) + vector_path = Path(config.regularization_vector_path).expanduser() + if not vector_path.exists(): + raise FileNotFoundError(f"Regularization vector file does not exist: {vector_path}") + regularization_vectors = load_regular_vector_dict(vector_path) + + # Get the model's trainable parameters to be regularized and the corresponding freezing anchors at startup + model_module = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + + trainable_entries = [] + missing_vectors = 0 + shape_mismatches = 0 + trainable_param_names = set() + + for name, param in model_module.named_parameters(): + if not param.requires_grad: + continue + trainable_param_names.add(name) + regularization_vector = regularization_vectors.get(name) + if regularization_vector is None: + missing_vectors += 1 + continue + anchor_param = param.detach().clone().contiguous() + if regularization_vector.shape != param.shape or anchor_param.shape != param.shape: + shape_mismatches += 1 + continue + trainable_entries.append( + { + "name": name, + "param": param, + "anchor": anchor_param, + "vector": regularization_vector.to(device=param.device, dtype=param.dtype).contiguous(), + } + ) + + logging.info( + "Regularization coverage: matched=%d missing_vectors=%d shape_mismatches=%d", + len(trainable_entries), + missing_vectors, + shape_mismatches, + ) + + return { + "entries": trainable_entries, + "weight": config.regularization_coeff, + "vector_path": str(vector_path), + } + + +def compute_regularization_loss(regularization_context: dict | None, device: torch.device) -> torch.Tensor: + """Compute the delta-based regularization loss for the current model parameters.""" + reg_loss = torch.zeros((), device=device, dtype=torch.float32) + + if not regularization_context: + return reg_loss + + for entry in regularization_context["entries"]: + param = entry["param"] + anchor = entry["anchor"] + vector = entry["vector"] + + delta = (param - anchor).reshape(-1).float() + direction = vector.reshape(-1).float() + reg_loss = reg_loss + torch.abs(torch.dot(delta, direction)) + + return reg_loss * regularization_context["weight"] + + +def save_checkpoint(model, optimizer, global_step, config, is_main, data_config): + """Save a checkpoint with model state, optimizer state, and metadata.""" + if not is_main: + return + + # Only save if it's time to save or if it's the final step + if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1: + # Create temporary directory for atomic checkpoint saving + final_ckpt_dir = config.checkpoint_dir / f"{global_step}" + tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}" + + # Remove any existing temp directory and create new one + if tmp_ckpt_dir.exists(): + shutil.rmtree(tmp_ckpt_dir) + tmp_ckpt_dir.mkdir(parents=True, exist_ok=True) + + # Save model state using safetensors (handle shared tensors) + model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors") + + # Save optimizer state using PyTorch format + torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt") + + # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues) + metadata = { + "global_step": global_step, + "config": dataclasses.asdict(config), + "timestamp": time.time(), + } + torch.save(metadata, tmp_ckpt_dir / "metadata.pt") + + # save norm stats + norm_stats = 
data_config.norm_stats + if norm_stats is not None and data_config.asset_id is not None: + _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats) + + # Atomically move temp directory to final location + if final_ckpt_dir.exists(): + shutil.rmtree(final_ckpt_dir) + tmp_ckpt_dir.rename(final_ckpt_dir) + + logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}") + + # Log checkpoint to wandb + if config.wandb_enabled: + wandb.log({"checkpoint_step": global_step}, step=global_step) + + +def load_checkpoint(model, optimizer, checkpoint_dir, device): + """Load the latest checkpoint and return the global step.""" + checkpoint_steps = [ + int(d.name) + for d in checkpoint_dir.iterdir() + if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") + ] + + if not checkpoint_steps: + raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + + latest_step = max(checkpoint_steps) + ckpt_dir = checkpoint_dir / f"{latest_step}" + + # Clear memory before loading checkpoints + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "before_loading_checkpoint") + + try: + # Load model state with error handling + logging.info("Loading model state...") + safetensors_path = ckpt_dir / "model.safetensors" + + if safetensors_path.exists(): + model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model + safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device)) + logging.info("Loaded model state from safetensors format") + else: + raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}") + + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_model") + + # Load optimizer state with error handling + logging.info("Loading optimizer state...") + optimizer_path = ckpt_dir / "optimizer.pt" + + if optimizer_path.exists(): + optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False) + logging.info("Loaded optimizer state from pt format") + else: + raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}") + + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_optimizer") + + # Load metadata + logging.info("Loading metadata...") + metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False) + global_step = metadata.get("global_step", latest_step) + del metadata + torch.cuda.empty_cache() + gc.collect() + log_memory_usage(device, latest_step, "after_loading_metadata") + + logging.info(f"Successfully loaded all checkpoint components from step {latest_step}") + return global_step + + except RuntimeError as e: + if "out of memory" in str(e): + # Clear memory and provide detailed error message + torch.cuda.empty_cache() + gc.collect() + logging.error(f"Out of memory error while loading checkpoint: {e!s}") + log_memory_usage(device, latest_step, "after_oom_error") + raise RuntimeError( + "Out of memory while loading checkpoint. 
Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" + ) from e + raise + + +def get_latest_checkpoint_step(checkpoint_dir): + """Get the latest checkpoint step number from a checkpoint directory.""" + checkpoint_steps = [ + int(d.name) + for d in checkpoint_dir.iterdir() + if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_") + ] + return max(checkpoint_steps) if checkpoint_steps else None + + +def log_memory_usage(device, step, phase="unknown"): + """Log detailed memory usage information.""" + if not torch.cuda.is_available(): + return + + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 + memory_reserved = torch.cuda.memory_reserved(device) / 1e9 + memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) + memory_free = memory_free / 1e9 + + # Get more detailed memory info + memory_stats = torch.cuda.memory_stats(device) + max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9 + max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9 + + # Get DDP info if available + ddp_info = "" + if dist.is_initialized(): + ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" + + logging.info( + f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}" + ) + + +def train_loop(config: _config.TrainConfig): + use_ddp, local_rank, device = setup_ddp() + is_main = (not use_ddp) or (dist.get_rank() == 0) + set_seed(config.seed, local_rank) + + # Initialize checkpoint directory and wandb + resuming = False + if config.resume: + # Find checkpoint directory based on experiment name + exp_checkpoint_dir = config.checkpoint_dir + if exp_checkpoint_dir.exists(): + # Use validation to find the latest working checkpoint + latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) + if latest_step is not None: + resuming = True + logging.info( + f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" + ) + else: + raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") + else: + raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") + elif config.overwrite and config.checkpoint_dir.exists(): + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + + # Create checkpoint directory with experiment name + if not resuming: + # For new runs, create experiment-specific checkpoint directory + exp_checkpoint_dir = config.checkpoint_dir + exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") + else: + # For resume, checkpoint_dir is already set to the experiment directory + logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") + + # Initialize wandb (only on main process) + if is_main: + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) + + # Build data loader using the unified data loader + # Calculate effective batch size per GPU for DDP + # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size + world_size = torch.distributed.get_world_size() if use_ddp else 1 + effective_batch_size = config.batch_size // world_size + logging.info( + f"Using batch size per GPU: 
{effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" + ) + + # Pass the original batch size to data loader - it will handle DDP splitting internally + loader, data_config = build_datasets(config) + + # Log sample images to wandb on first batch + if is_main and config.wandb_enabled and not resuming: + # Create a separate data loader for sample batch to avoid consuming the main loader + sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) + sample_batch = next(iter(sample_data_loader)) + # Convert observation and actions to torch tensors + observation, actions = sample_batch + sample_batch = observation.to_dict() + sample_batch["actions"] = actions + + # Create sample images for wandb + images_to_log = [] + # Get batch size from the first image tensor + batch_size = next(iter(sample_batch["image"].values())).shape[0] + for i in range(min(5, batch_size)): + # Concatenate all camera views horizontally for this batch item + # Convert from NCHW to NHWC format for wandb + img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) + img_concatenated = img_concatenated.cpu().numpy() + images_to_log.append(wandb.Image(img_concatenated)) + + wandb.log({"camera_views": images_to_log}, step=0) + + # Clear sample batch from memory aggressively + del sample_batch, observation, actions, images_to_log, img_concatenated + del sample_data_loader # Also delete the sample data loader + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logging.info("Cleared sample batch and data loader from memory") + + # Build model + if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): + # Convert dataclass to Pi0Config if needed + model_cfg = openpi.models.pi0_config.Pi0Config( + dtype=config.pytorch_training_precision, + action_dim=config.model.action_dim, + action_horizon=config.model.action_horizon, + max_token_len=config.model.max_token_len, + paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), + action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), + pi05=getattr(config.model, "pi05", False), + ) + else: + model_cfg = config.model + # Update dtype to match pytorch_training_precision + object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) + + model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device) + + if hasattr(model, "gradient_checkpointing_enable"): + enable_gradient_checkpointing = True + model.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for memory optimization") + else: + enable_gradient_checkpointing = False + logging.info("Gradient checkpointing is not supported for this model") + + # Log initial memory usage after model creation + if is_main and torch.cuda.is_available(): + log_memory_usage(device, 0, "after_model_creation") + + # Enable memory optimizations for large-scale training + if world_size >= 8: + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # Set memory allocation configuration + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" + logging.info("Enabled memory optimizations for 8+ GPU training") + + if use_ddp: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device.index] if device.type == "cuda" else None, + find_unused_parameters=True, # Disable for memory efficiency + 
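+            # NOTE: with static_graph=True (set below for 8+ GPUs), DDP records the
+            # graph once and find_unused_parameters no longer needs to be True; keeping
+            # it True adds a per-step scan for parameters without gradients.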
gradient_as_bucket_view=True, # Enable for memory efficiency + static_graph=world_size >= 8, # Enable for 8+ GPUs + ) + + # Load weights from weight_loader if specified (for fine-tuning) + if config.pytorch_weight_path is not None: + logging.info(f"Loading weights from: {config.pytorch_weight_path}") + + model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") + safetensors.torch.load_model( + (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path + ) + logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") + + regularization_context = prepare_regularization_context(model, config) + + # Optimizer + learning rate schedule from config + warmup_steps = config.lr_schedule.warmup_steps + peak_lr = config.lr_schedule.peak_lr + decay_steps = config.lr_schedule.decay_steps + end_lr = config.lr_schedule.decay_lr + + # Create optimizer with config parameters + optim = torch.optim.AdamW( + model.parameters(), + lr=peak_lr, + betas=(config.optimizer.b1, config.optimizer.b2), + eps=config.optimizer.eps, + weight_decay=config.optimizer.weight_decay, + ) + + # Load checkpoint if resuming + global_step = 0 + if resuming: + global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) + logging.info(f"Resumed training from step {global_step}") + + def lr_schedule(step: int): + if step < warmup_steps: + # Match JAX behavior: start from peak_lr / (warmup_steps + 1) + init_lr = peak_lr / (warmup_steps + 1) + return init_lr + (peak_lr - init_lr) * step / warmup_steps + # cosine decay + progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) + cos = 0.5 * (1 + np.cos(np.pi * progress)) + return end_lr + (peak_lr - end_lr) * cos + + model.train() + start_time = time.time() + infos = [] # Collect stats over log interval + if is_main: + logging.info( + f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}" + ) + logging.info( + f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" + ) + logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") + logging.info( + f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" + ) + logging.info( + f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" + ) + logging.info("EMA is not supported for PyTorch training") + logging.info(f"Training precision: {model_cfg.dtype}") + if regularization_context: + logging.info( + "Delta-based regularization: enabled | weight=%.2e | vector=%s", + config.regularization_coeff, + regularization_context["vector_path"], + ) + + # Training loop - iterate until we reach num_train_steps + pbar = ( + tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) + if is_main + else None + ) + + while global_step < config.num_train_steps: + # Set epoch for distributed training + if use_ddp and hasattr(loader, "set_epoch"): + loader.set_epoch(global_step // len(loader)) + + for observation, actions in loader: + # Check if we've reached the target number of steps + if global_step >= config.num_train_steps: + break + + # The unified data loader returns (observation, actions) tuple + observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901 + actions = 
actions.to(torch.float32) # noqa: PLW2901 + actions = actions.to(device) # noqa: PLW2901 + + # Update LR + for pg in optim.param_groups: + pg["lr"] = lr_schedule(global_step) + + # Forward pass + losses = model(observation, actions) + # Ensure losses is a tensor and handle different return types + if isinstance(losses, list | tuple): + losses = torch.stack(losses) + elif not isinstance(losses, torch.Tensor): + losses = torch.tensor(losses, device=device, dtype=torch.float32) + + action_loss = losses.mean() + regularization_loss = compute_regularization_loss(regularization_context, device) + total_loss = action_loss + regularization_loss + + # Backward pass + total_loss.backward() + + # Log memory usage after backward pass + if global_step < 5 and is_main and torch.cuda.is_available(): + log_memory_usage(device, global_step, "after_backward") + + # Gradient clipping + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) + + # Optimizer step + optim.step() + optim.zero_grad(set_to_none=True) + + # Clear gradients more aggressively + for param in model.parameters(): + if param.grad is not None: + param.grad.detach_() + param.grad = None + + # Collect stats + if is_main: + infos.append( + { + "action_loss": action_loss.item(), + "regularization_loss": regularization_loss.item(), + "total_loss": total_loss.item(), + "learning_rate": optim.param_groups[0]["lr"], + "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm, + } + ) + + if is_main and (global_step % config.log_interval == 0): + elapsed = time.time() - start_time + + # Average stats over log interval + avg_action_loss = sum(info["action_loss"] for info in infos) / len(infos) + avg_regularization_loss = sum(info["regularization_loss"] for info in infos) / len(infos) + avg_total_loss = sum(info["total_loss"] for info in infos) / len(infos) + avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) + + avg_grad_norm = None + if any("grad_norm" in info for info in infos): + vals = [ + info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None + ] + if len(vals) > 0: + avg_grad_norm = sum(vals) / len(vals) + logging.info( + f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" + if avg_grad_norm is not None + else f"step={global_step} action_loss={avg_action_loss:.4f} regularization_loss={avg_regularization_loss:.4f} total_loss={avg_total_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s" + ) + + # Log to wandb + if config.wandb_enabled and len(infos) > 0: + log_payload = { + "action_loss": avg_action_loss, + "regularization_loss": avg_regularization_loss, + "total_loss": avg_total_loss, + "learning_rate": avg_lr, + "step": global_step, + "time_per_step": elapsed / config.log_interval, + } + if avg_grad_norm is not None: + log_payload["grad_norm"] = avg_grad_norm + wandb.log(log_payload, step=global_step) + + start_time = time.time() + infos = [] # Reset stats collection + + global_step += 1 + # Save checkpoint using the new mechanism + save_checkpoint(model, optim, global_step, config, is_main, data_config) + + # Update progress bar + if pbar is not None: + pbar.update(1) + pbar.set_postfix( + { + "action_loss": f"{action_loss.item():.4f}", + "reg_loss": f"{regularization_loss.item():.4f}", + "total_loss": f"{total_loss.item():.4f}", + "lr": 
f"{optim.param_groups[0]['lr']:.2e}", + "step": global_step, + } + ) + + # Close progress bar + if pbar is not None: + pbar.close() + + # Finish wandb run + if is_main and config.wandb_enabled: + wandb.finish() + + cleanup_ddp() + + +def main(): + init_logging() + config = _config.cli() + train_loop(config) + + +if __name__ == "__main__": + main() diff --git a/capvector-pi05/scripts/train_test.py b/capvector-pi05/scripts/train_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0a31234f68eadbe2721c18c1c98967f68280dc --- /dev/null +++ b/capvector-pi05/scripts/train_test.py @@ -0,0 +1,30 @@ +import dataclasses +import os +import pathlib + +import pytest + +os.environ["JAX_PLATFORMS"] = "cpu" + +from openpi.training import config as _config + +from . import train + + +@pytest.mark.parametrize("config_name", ["debug"]) +def test_train(tmp_path: pathlib.Path, config_name: str): + config = dataclasses.replace( + _config._CONFIGS_DICT[config_name], # noqa: SLF001 + batch_size=2, + checkpoint_base_dir=str(tmp_path / "checkpoint"), + exp_name="test", + overwrite=False, + resume=False, + num_train_steps=2, + log_interval=1, + ) + train.main(config) + + # test resuming + config = dataclasses.replace(config, resume=True, num_train_steps=4) + train.main(config) diff --git a/capvector-pi05/src/openpi/__init__.py b/capvector-pi05/src/openpi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-pi05/src/openpi/conftest.py b/capvector-pi05/src/openpi/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc58102eae19451f63992024bb9a856c78b46a5 --- /dev/null +++ b/capvector-pi05/src/openpi/conftest.py @@ -0,0 +1,17 @@ +import os + +import pynvml +import pytest + + +def set_jax_cpu_backend_if_no_gpu() -> None: + try: + pynvml.nvmlInit() + pynvml.nvmlShutdown() + except pynvml.NVMLError: + # No GPU found. + os.environ["JAX_PLATFORMS"] = "cpu" + + +def pytest_configure(config: pytest.Config) -> None: + set_jax_cpu_backend_if_no_gpu() diff --git a/capvector-pi05/src/openpi/models/__init__.py b/capvector-pi05/src/openpi/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-pi05/src/openpi/models/gemma.py b/capvector-pi05/src/openpi/models/gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..128a286cae227a461c88dde0e3f7f3b7bb21bce6 --- /dev/null +++ b/capvector-pi05/src/openpi/models/gemma.py @@ -0,0 +1,459 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemma adaptation for Pi, taken from big_vision. 
+ +We follow this einsum axis naming convention: + B: batch + T: query length + S: k/v length + N: num query heads + K: num k/v heads + G: num query heads per k/v head + H: head dim + D: d_model ("features") +""" + +from collections.abc import Sequence +import dataclasses +from typing import Literal, TypeAlias + +import einops +import flax.linen as nn +import jax +import jax.numpy as jnp + +import openpi.models.lora as lora +import openpi.shared.array_typing as at +import openpi.training.sharding as sharding + +PALIGEMMA_VOCAB_SIZE = 257_152 + + +@dataclasses.dataclass +class Config: + width: int + depth: int + mlp_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict) + + +Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"] + + +def get_config(variant: Variant) -> Config: + """Returns config for specified gemma variant.""" + if variant == "dummy": + return Config( + width=64, + depth=4, + mlp_dim=128, + num_heads=8, + num_kv_heads=1, + head_dim=16, + ) + if variant == "gemma_300m": + # 311M params + return Config( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + if variant == "gemma_2b": + return Config( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + if variant == "gemma_2b_lora": + return Config( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)}, + ) + if variant == "gemma_300m_lora": + # 311M params + return Config( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)}, + ) + raise ValueError(f"Unknown variant: {variant}") + + +@at.typecheck +class RMSNorm(nn.Module): + @nn.compact + def __call__(self, x, cond): + dtype = x.dtype # original dtype, could be half-precision + var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32 + normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32 + if cond is None: + # regular RMSNorm + scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) + normed_inputs = normed_inputs * ( + 1 + scale + ) # scale by learned parameter in float32 (matches Flax implementation) + return normed_inputs.astype(dtype), None # return in original dtype + + # adaptive RMSNorm + modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond) + scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1) + normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32 + return normed_inputs.astype(dtype), gate + + +@at.typecheck +class Embedder(nn.Module): + """Embedder module.""" + + vocab_size: int + embed_dim: int + + def setup(self): + self.input_embedding_table = self.param( + "input_embedding", + nn.initializers.normal(), + (self.vocab_size, self.embed_dim), + ) + + def encode(self, x): + x = self.input_embedding_table[(x,)] + x *= jnp.sqrt(self.embed_dim).astype(x.dtype) + return x + + def decode(self, x): + return jnp.dot(x, self.input_embedding_table.T) + + +@at.typecheck +class Attention(nn.Module): + """Attention module.""" + + configs: Sequence[Config] + + @nn.compact + def 
__call__(self, xs, positions, attn_mask, kv_cache): + # all experts must share the same head dim, num heads, and num kv heads for self-attention to work + assert all(config.head_dim == self.configs[0].head_dim for config in self.configs) + assert all(config.num_heads == self.configs[0].num_heads for config in self.configs) + assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs) + + dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision + + qkvs = [] + for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): + if x is None: + continue + if config.num_kv_heads == config.num_heads: + qkv_einsum = lora.Einsum( + shape=(3, config.num_heads, config.width, config.head_dim), + name=_name("qkv_einsum", i), + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), + lora_config=config.lora_configs.get("attn"), + ) + qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x)) + else: + q_einsum = lora.Einsum( + shape=(config.num_heads, config.width, config.head_dim), + name=_name("q_einsum", i), + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + lora_config=config.lora_configs.get("attn"), + ) + q = q_einsum("BTD,NDH->BTNH", x) + kv_einsum = lora.Einsum( + shape=(2, config.num_kv_heads, config.width, config.head_dim), + name=_name("kv_einsum", i), + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), + lora_config=config.lora_configs.get("attn"), + ) + k, v = kv_einsum("BSD,2KDH->2BSKH", x) + qkvs.append((q, k, v)) + + q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True)) + + q = _apply_rope(q, positions=positions) + q *= self.configs[0].head_dim ** -0.5 + + k = _apply_rope(k, positions=positions) + + # should still be half-precision here (if input was half-precision) + assert q.dtype == k.dtype == v.dtype == dtype + + if kv_cache is not None: + cache_k, cache_v = kv_cache + k = jnp.concatenate([cache_k, k], axis=1) + v = jnp.concatenate([cache_v, v], axis=1) + + q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads) + logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32) + + if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): + raise ValueError( + f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}" + ) + + # big_neg = jnp.finfo(logits.dtype).min + big_neg = -2.3819763e38 # See gemma/modules.py + masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) + + probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype) + + encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) + encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") + + out = [] + start = 0 + for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): + if x is not None: + end = start + x.shape[1] + out_einsum = lora.Einsum( + shape=(config.num_heads, config.head_dim, config.width), + name=_name("attn_vec_einsum", i), + init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1), + lora_config=config.lora_configs.get("attn"), + ) + out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end])) + start = end + else: + out.append(None) + + return out, (k, v) + + +@at.typecheck +class FeedForward(nn.Module): + """Feed forward module.""" + + features: int + hidden_dim: int + + @nn.compact + def __call__(self, x): + dtype = x.dtype # original dtype, could be half-precision + w_gating = self.param( + 
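+            # GeGLU-style gating: w_gating[0] feeds the GELU-gated branch and
+            # w_gating[1] the linear branch; their elementwise product then goes
+            # through the w_linear output projection below.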
"gating_einsum", + nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + (2, self.features, self.hidden_dim), + ).astype(dtype) + ff_gate = jnp.dot(x, w_gating[0]) + gate_value = nn.gelu(ff_gate) + + ff1 = jnp.dot(x, w_gating[1]) + activations = gate_value * ff1 + + w_linear = self.param( + "linear", + nn.initializers.lecun_normal(in_axis=-2, out_axis=-1), + (self.hidden_dim, self.features), + ).astype(dtype) + outputs = jnp.dot(activations, w_linear) + assert outputs.dtype == dtype + return outputs + + +@at.typecheck +class Block(nn.Module): + """Transformer block.""" + + configs: tuple[Config, ...] + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () + + @nn.compact + def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002 + xs = sharding.activation_sharding_constraint(xs) + drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x + + attn = Attention(configs=self.configs, name="attn") + + pre_attn = [] + gates = [] + for i, x in enumerate(xs): + if x is not None: + x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901 + pre_attn.append(x) + gates.append(gate if x is not None else None) + + pre_attn = sharding.activation_sharding_constraint(pre_attn) + post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache) + post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn) + post_attn = sharding.activation_sharding_constraint(post_attn) + xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)] + xs = sharding.activation_sharding_constraint(xs) + + out = [] + gates = [] + for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): + if x is not None: + x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901 + x = lora.FeedForward( # noqa: PLW2901 + features=config.width, + hidden_dim=config.mlp_dim, + name=_name("mlp", i), + lora_config=config.lora_configs.get("ffn"), + )(x) + out.append(x) + gates.append(gate if x is not None else None) + + out = sharding.activation_sharding_constraint(out) + out = jax.tree.map(lambda x: drop(x, deterministic), out) + xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)] + xs = sharding.activation_sharding_constraint(xs) + + return xs, kv_cache + + +KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]] + + +@at.typecheck +class Module(nn.Module): + """Transformer model, supporting a mixture of different weights for different tokens.""" + + configs: Sequence[Config] # list of configs, one for each expert + embed_dtype: str + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () # Every float is dropped independently. 
+ adarms: bool = False + + def setup(self): + # all experts must have the same depth + assert all(config.depth == self.configs[0].depth for config in self.configs) + + self.embedder = Embedder( + vocab_size=PALIGEMMA_VOCAB_SIZE, + embed_dim=self.configs[0].width, # embedder for first expert only + name="embedder", + ) + block_cls = nn.remat( + Block, + prevent_cse=False, + static_argnums=(5,), # 0=self, 6=deterministic + policy=jax.checkpoint_policies.nothing_saveable, + ) + self.layers = nn.scan( + block_cls, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=( + 0, + nn.broadcast, + nn.broadcast, + nn.broadcast, + nn.broadcast, + ), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic + length=self.configs[0].depth, + )( + configs=self.configs, + dropout=self.dropout, + dropout_bdims=self.dropout_bdims, + ) + self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))] + + @at.typecheck + def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]: + return self.embedder.encode(tokens).astype(self.embed_dtype) + + @at.typecheck + def __call__( + self, + # list of token arrays, one for each expert, or None if that expert should not be run + embedded: Sequence[at.Float[at.Array, "b _t _d"] | None], + positions: at.Int[at.Array, "b t"], + mask: at.Bool[at.Array, "b t s"], + adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None, + *, + kv_cache: KVCache | None = None, + deterministic: bool = True, + ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]: + embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded) + mask = jnp.asarray(mask)[:, None, :, :] + if adarms_cond is None: + adarms_cond = [None] * len(self.configs) + + embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic) + + assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None) + + return [ + f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True) + ], kv_cache + + def init(self, use_adarms: Sequence[bool]): + """Convenience method for initializing all parameters, necessary due to the quirks of linen.""" + self.embed(jnp.zeros((1, 1), dtype=jnp.int32)) + self( + [jnp.zeros((1, 1, c.width)) for c in self.configs], + jnp.zeros((1, len(self.configs)), dtype=jnp.int32), + jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool), + adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)], + ) + + +def _apply_rope(x, *, positions, max_wavelength=10_000): + """Applies RoPE positions [B, L] to x [B, L, H, D].""" + freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) + timescale = max_wavelength**freq_exponents + radians = positions[..., None] / timescale[None, None, :] + radians = radians[..., None, :] + assert radians.dtype == jnp.float32 + # radians.shape = [...,L,1,d=D/2] + sin, cos = jnp.sin(radians), jnp.cos(radians) + x1, x2 = jnp.split(x, 2, axis=-1) + res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + assert res.dtype == jnp.float32 + # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache + # dtype when in inference mode (but not in training mode). I don't think any of this was intentional. 
Based on the + # original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16 + # here. + return res.astype(x.dtype) + + +def _name(name, i): + # we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they + # can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g., + # "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma, + # and the action expert. + if i == 0: + return name + return f"{name}_{i}" + + +def _gated_residual(x, y, gate): + assert (x is None) == (y is None) + if x is None: + return None + if gate is None: + return x + y + return x + y * gate diff --git a/capvector-pi05/src/openpi/models/gemma_fast.py b/capvector-pi05/src/openpi/models/gemma_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba787601e143338a54241dab96c0fb6a311de89 --- /dev/null +++ b/capvector-pi05/src/openpi/models/gemma_fast.py @@ -0,0 +1,437 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility) +Used for FAST autoregressive policies. +""" + +import dataclasses +from typing import Literal, TypeAlias + +import einops +import flax.linen as nn +import jax +import jax.numpy as jnp +import ml_collections + +import openpi.models.lora as lora +import openpi.shared.array_typing as at + +Variant = Literal["gemma_2b", "gemma_2b_lora"] + + +def get_config(variant): + """Returns config for specified gemma variant.""" + if variant == "gemma_2b": + return ml_collections.ConfigDict( + { + "variant": variant, + "width": 2048, + "depth": 18, + "mlp_dim": 16_384, + "num_heads": 8, + "num_kv_heads": 1, + "head_dim": 256, + "norm_eps": 1e-6, + "vocab_size": 257_152, + "scan": True, + "remat_policy": "nothing_saveable", + } + ) + if variant == "gemma_2b_lora": + return ml_collections.ConfigDict( + { + "variant": variant, + "width": 2048, + "depth": 18, + "mlp_dim": 16_384, + "num_heads": 8, + "num_kv_heads": 1, + "head_dim": 256, + "norm_eps": 1e-6, + "vocab_size": 257_152, + "scan": True, + "remat_policy": "nothing_saveable", + "lora_configs": { + "attn": lora.LoRAConfig(rank=16, alpha=16.0), + "ffn": lora.LoRAConfig(rank=16, alpha=16.0), + }, + } + ) + raise ValueError(f"Unknown variant: {variant}") + + +@at.typecheck +class Einsum(nn.Module): + shape: tuple[int, ...] 
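+    # Minimal (non-LoRA) einsum layer: owns a single weight tensor of `shape` and
+    # contracts it with the input according to the equation passed at call time.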
+ + @nn.compact + def __call__(self, eqn, x): + dtype = x.dtype # original dtype, could be half-precision + w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype) + return jnp.einsum(eqn, x, w) + + +@at.typecheck +class RMSNorm(nn.Module): + @nn.compact + def __call__(self, x): + dtype = x.dtype # original dtype, could be half-precision + scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) + var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32 + normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32 + normed_inputs = normed_inputs * ( + 1 + scale + ) # scale by learned parameter in float32 (matches Flax implementation) + return normed_inputs.astype(dtype) # return in original dtype + + +@at.typecheck +class Embedder(nn.Module): + """Embedder module.""" + + vocab_size: int + embed_dim: int + + def setup(self): + self.input_embedding_table = self.param( + "input_embedding", + nn.initializers.zeros_init(), + (self.vocab_size, self.embed_dim), + ) + + def encode(self, x): + x = self.input_embedding_table[(x,)] + x *= jnp.sqrt(self.embed_dim).astype(x.dtype) + return x + + def decode(self, x): + return jnp.dot(x, self.input_embedding_table.T) + + +@at.typecheck +class Attention(nn.Module): + """Attention module.""" + + num_heads: int + num_kv_heads: int + features: int + head_dim: int + + cache_dtype: str | None = None + + lora_config: lora.LoRAConfig | None = None + + def setup(self): + if self.num_kv_heads == self.num_heads: + self.qkv_einsum = lora.Einsum( + shape=(3, self.num_heads, self.features, self.head_dim), + name="qkv_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), + lora_config=self.lora_config, + ) + else: + self.q_einsum = lora.Einsum( + shape=(self.num_heads, self.features, self.head_dim), + name="q_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + lora_config=self.lora_config, + ) + self.kv_einsum = lora.Einsum( + shape=(2, self.num_kv_heads, self.features, self.head_dim), + name="kv_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), + lora_config=self.lora_config, + ) + self.attn_vec_einsum = lora.Einsum( + shape=(self.num_heads, self.head_dim, self.features), + name="attn_vec_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + lora_config=self.lora_config, + ) + + def _init_cache(self, k, v, cache_size): + """Initialize KV cache""" + prefill_len = k.shape[1] + pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0)) + cache_dtype = self.cache_dtype or k.dtype + k_cache = jnp.pad(k.astype(cache_dtype), pad_width) + v_cache = jnp.pad(v.astype(cache_dtype), pad_width) + idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len + return idx, k_cache, v_cache + + def _update_cache(self, k, v, idx, k_cache, v_cache): + """Update KV cache with new values""" + assert k.shape[1] == 1, "Only support kv-cache updates of length 1" + indices = (0, idx[0], 0, 0) + cache_dtype = self.cache_dtype or k.dtype + k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices) + v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices) + idx_new = idx + 1 + return idx_new, k_new, v_new + + @nn.compact + def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002 + dtype = x.dtype # original dtype, could 
be half-precision + if self.num_kv_heads == self.num_heads: + q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x) + else: + q = self.q_einsum("BTD,NDH->BTNH", x) + k, v = self.kv_einsum("BSD,2KDH->2BSKH", x) + + q = _apply_rope(q, positions=positions) # promotes to float32 + q *= self.head_dim**-0.5 + + k = _apply_rope(k, positions=positions) # promotes to float32 + + if kv_cache is None: + idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1]) + else: + idx, k_cache, v_cache = kv_cache + idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache) + + k, v = k_cache, v_cache + kv_cache = (idx, k_cache, v_cache) + + q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads) + logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32) + + if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): + raise ValueError( + f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}" + ) + + # big_neg = jnp.finfo(logits.dtype).min + big_neg = -2.3819763e38 # See gemma/modules.py + masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) + + probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype) + + encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) + encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") + return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache + + +@at.typecheck +class Block(nn.Module): + """Transformer block.""" + + num_heads: int + num_kv_heads: int + embed_dim: int + head_dim: int + hidden_dim: int + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () + cache_dtype: str | None = None + lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict) + + def setup(self): + self.pre_attention_norm = RMSNorm() + self.attn = Attention( + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + features=self.embed_dim, + head_dim=self.head_dim, + cache_dtype=self.cache_dtype, + lora_config=self.lora_configs.get("attn"), + ) + self.pre_ffw_norm = RMSNorm() + self.mlp = lora.FeedForward( + features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn") + ) + if self.dropout: + self.drop = nn.Dropout(self.dropout, self.dropout_bdims) + else: + self.drop = lambda x, _: x + + def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002 + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + inputs_normalized = self.pre_attention_norm(x) + attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic) + attn_output = self.drop(attn_output, deterministic) + attn_output += x + residual = attn_output + attn_output = self.pre_ffw_norm(attn_output) + outputs = self.mlp(attn_output) + outputs = self.drop(outputs, deterministic) + outputs = residual + outputs + return outputs, kv_cache + + +KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]] + + +@at.typecheck +class Module(nn.Module): + """gemma model.""" + + variant: str + + width: int + depth: int + mlp_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + norm_eps: float + vocab_size: int + embed_dtype: str + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () # Every float is dropped independently. 
+    cache_dtype: str | None = None
+
+    scan: bool = False
+    remat_policy: str = "none"
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
+
+    @nn.compact
+    def __call__(
+        self,
+        tokens=None,
+        embedded_prefix=None,
+        embed_only=False,  # noqa: FBT002
+        pre_logits=None,
+        positions=None,
+        mask=None,
+        decode=False,  # noqa: FBT002
+        kv_cache=None,
+        deterministic=True,  # noqa: FBT002
+        return_prelogits=False,  # noqa: FBT002
+    ):
+        """Embed only, or complete forward pass.
+
+        Args:
+          tokens: Embedded and then appended to `embedded_prefix`. Can be None.
+          embedded_prefix: Optional prefix that is already embedded.
+          embed_only: Whether to compute embeddings only.
+          pre_logits: If present, computes logits from pre_logits and returns them.
+          positions: Optional `[B, T]` tensor specifying the absolute positions of
+            the tokens.
+          mask: Optional attention mask `[B, T, S]`.
+          decode: Whether to use kv-cache. Caller must pass masks and positions.
+          deterministic: Forwarded to all dropout layers.
+          return_prelogits: Whether to return the pre-logits.
+
+        Returns:
+          If `embed_only=False`, then `(logits, out)` will be returned.
+          If `embed_only=True`, then the embeddings will be returned.
+          If `return_prelogits=True`, then the pre-logits will be returned.
+        """
+        out = {}
+
+        embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
+
+        if pre_logits is not None:
+            x = out["pre_logits"] = pre_logits
+            logits = out["logits"] = embedder.decode(x)
+            return logits, out
+
+        x = []
+        if embedded_prefix is not None:
+            x.append(embedded_prefix)
+        if tokens is not None:
+            x.append(embedder.encode(tokens))
+
+        x = jnp.concatenate(x, axis=-2)
+        x = x.astype(self.embed_dtype)
+        batch_size, seq_len, width = x.shape
+
+        if embed_only:
+            return x
+
+        if decode:
+            assert positions is not None and mask is not None, (  # noqa: PT018
+                "Must explicitly pass positions and mask for decoding."
+            )
+
+        if positions is None:
+            positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
+        assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
+
+        if mask is None:
+            mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
+        if mask.ndim == 3:
+            mask = mask[:, None, :, :]
+        cache_size = max(seq_len, mask.shape[-1])
+        assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
+
+        if self.remat_policy == "none":
+            block_cls = Block
+        else:
+            block_cls = nn.remat(
+                Block,
+                prevent_cse=not self.scan,
+                static_argnums=(5, 6),  # 0=self, 5=decode, 6=deterministic
+                policy=getattr(jax.checkpoint_policies, self.remat_policy),
+            )
+
+        block_kw = {
+            "num_heads": self.num_heads,
+            "head_dim": self.head_dim,
+            "num_kv_heads": self.num_kv_heads,
+            "embed_dim": width,
+            "hidden_dim": self.mlp_dim,
+            "dropout": self.dropout,
+            "dropout_bdims": self.dropout_bdims,
+            "cache_dtype": self.cache_dtype,
+            "lora_configs": self.lora_configs,
+        }
+        layers = self.scope.push("layers")
+        blocks = [
+            nn.scan(
+                block_cls,
+                variable_axes={"params": 0},
+                split_rngs={"params": True, "dropout": True},
+                in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),  # 0=kv_cache, 1=positions, 2=mask
+                length=self.depth,
+            )(parent=layers, **block_kw)
+        ]
+        for block in blocks:
+            x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
+
+        assert x.dtype == jnp.dtype(self.embed_dtype)  # Sanity check.
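As a quick illustration of the `[B, 1, T, S]` mask convention enforced by the assertions above (a standalone sketch with made-up sizes; the forward pass continues below): `make_causal_mask` already emits the head axis, and `S` grows to the KV-cache size during decoding.

```python
import flax.linen as nn
import jax.numpy as jnp

batch_size, seq_len = 2, 5
mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
assert mask.shape == (batch_size, 1, seq_len, seq_len)  # [B, 1, T, S] with S == T

# During decoding, keys/values live in a fixed-size cache, so the last axis is
# padded out to the cache size while the query axis shrinks to 1.
cache_size = 8
decode_mask = jnp.pad(mask[:, :, -1:, :], ((0, 0), (0, 0), (0, 0), (0, cache_size - seq_len)))
assert decode_mask.shape == (batch_size, 1, 1, cache_size)
```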
+ out["encoded"] = x + + x = RMSNorm(name="final_norm")(x) + out["pre_logits"] = x + if return_prelogits: + return x, kv_cache, out + + x = embedder.decode(x) + out["logits"] = x + + return x, kv_cache, out + + def init(self): + """Convenience method for initializing all parameters, necessary due to the quirks of linen.""" + self(jnp.zeros((1, 1), dtype=jnp.int32)) + + +def _apply_rope(x, *, positions, max_wavelength=10_000): + """Applies RoPE positions [B, L] to x [B, L, H, D].""" + freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) + timescale = max_wavelength**freq_exponents + radians = positions[..., None] / timescale[None, None, :] + radians = radians[..., None, :] + assert radians.dtype == jnp.float32 + # radians.shape = [...,L,1,d=D/2] + sin, cos = jnp.sin(radians), jnp.cos(radians) + x1, x2 = jnp.split(x, 2, axis=-1) + res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + assert res.dtype == jnp.float32 + return res diff --git a/capvector-pi05/src/openpi/models/lora.py b/capvector-pi05/src/openpi/models/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..0524f2e10023837634e3855d87f35a53736ccd07 --- /dev/null +++ b/capvector-pi05/src/openpi/models/lora.py @@ -0,0 +1,148 @@ +import math +import re + +import flax.linen as nn +import flax.struct as struct +import jax.numpy as jnp + +import openpi.shared.array_typing as at + + +@struct.dataclass +class LoRAConfig: + """Configuration for LoRA.""" + + # LoRA rank. + rank: int + # LoRA scaling factor. + alpha: float = 1.0 + # Initialization function for LoRA parameters. + init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01) + # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732 + rslora: bool = False + # Axes in the weight to apply LoRA to. Should typically be the last two axes. + axes: tuple[int, int] = (-2, -1) + # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation. + label: str = "L" + + @property + def scaling_value(self) -> float: + return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank + + +class Einsum(nn.Module): + """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum.""" + + # Shape of the weight. + shape: tuple[int, ...] + # Initialization function for the weight. + init_fn: nn.initializers.Initializer = nn.initializers.zeros + # If not None, apply LoRA to the weight. + lora_config: LoRAConfig | None = None + + def setup(self): + self.w = self.param("w", self.init_fn, self.shape) + + if config := self.lora_config: + # Setup LoRA parameters. 
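The shape bookkeeping that follows is the core of the trick, so here is a standalone sketch first (all sizes are made up): LoRA swaps the rank axis into each factor, and contracting the two factors reproduces the full weight shape.

```python
import jax
import jax.numpy as jnp

N, D, H, rank = 8, 32, 4, 2  # heads, features, head_dim, LoRA rank
key_a, key_b = jax.random.split(jax.random.key(0))
w_a = jax.random.normal(key_a, (N, D, rank))  # shape_a: axes[1] (H) replaced by rank
w_b = jax.random.normal(key_b, (N, rank, H))  # shape_b: axes[0] (D) replaced by rank
delta_w = jnp.einsum("NDL,NLH->NDH", w_a, w_b)  # low-rank update, same shape as w
assert delta_w.shape == (N, D, H)
```

The `setup` code below builds exactly these `shape_a`/`shape_b` variants from `self.shape`.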
+ shape_a, shape_b = list(self.shape), list(self.shape) + shape_a[config.axes[1]] = config.rank + shape_b[config.axes[0]] = config.rank + self.w_a = self.param("lora_a", config.init_fn, shape_a) + self.w_b = self.param("lora_b", config.init_fn, shape_b) + + @nn.compact + def __call__(self, eqn: str, x): + dtype = x.dtype # original dtype, could be half-precision + result = jnp.einsum(eqn, x, self.w.astype(dtype)) + + if config := self.lora_config: + eqn_a, eqn_b = self._make_lora_eqns(eqn) + lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype)) + lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype)) + result = result + lora * config.scaling_value + + return result + + def _make_lora_eqns(self, eqn: str) -> tuple[str, str]: + if "L" in eqn: + raise ValueError(f"L already in eqn: {eqn}") + if not (m := re.match("(.*),(.*)->(.*)", eqn)): + raise ValueError(f"Unsupported einsum eqn: {eqn}") + lhs, rhs, out = m.groups() + + assert self.lora_config is not None + a_label, b_label = (rhs[x] for x in self.lora_config.axes) + label = self.lora_config.label + + a_rhs = rhs.replace(b_label, label) + a_out = out.replace(b_label, label) + eqn_a = f"{lhs},{a_rhs}->{a_out}" + + b_rhs = rhs.replace(a_label, label) + eqn_b = f"{a_out},{b_rhs}->{out}" + + return eqn_a, eqn_b + + +class FeedForward(nn.Module): + """Feed forward module.""" + + features: int + hidden_dim: int + # If not None, apply LoRA to the weight. + lora_config: LoRAConfig | None = None + + def setup(self): + self.w_gating = self.param( + "gating_einsum", + nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + (2, self.features, self.hidden_dim), + ) + self.w_linear = self.param( + "linear", + nn.initializers.lecun_normal(in_axis=-2, out_axis=-1), + (self.hidden_dim, self.features), + ) + self.w_gating_lora = None + self.w_linear_lora = None + if self.lora_config: + # Setup LoRA parameters. + # TODO: follow up with a simplified init_fn api. 
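Before `FeedForward`'s LoRA parameters are set up below, a worked example of what `_make_lora_eqns` above produces for the q-projection einsum used in gemma.py (the equation is taken from that file; label `L` is the rank axis):

```python
import jax.numpy as jnp

eqn = "BTD,NDH->BTNH"      # lhs "BTD", rhs "NDH", out "BTNH"
# axes (-2, -1) pick "D" and "H" in the rhs, so the derived equations are:
eqn_a = "BTD,NDL->BTNL"    # x @ lora_a: contracts D, emits the rank axis L
eqn_b = "BTNL,NLH->BTNH"   # (x @ lora_a) @ lora_b: contracts L, restores H

x = jnp.ones((2, 7, 32))           # [B, T, D]
w_a = jnp.full((8, 32, 2), 0.01)   # [N, D, rank]
w_b = jnp.full((8, 2, 4), 0.01)    # [N, rank, H]
delta = jnp.einsum(eqn_b, jnp.einsum(eqn_a, x, w_a), w_b)
assert delta.shape == (2, 7, 8, 4)  # [B, T, N, H], same as the base einsum output
```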
+ self.w_gating_lora = ( + self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)), + self.param( + "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim) + ), + ) + self.w_linear_lora = ( + self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)), + self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)), + ) + + @nn.compact + def __call__(self, x): + dtype = x.dtype # original dtype, could be half-precision + ff_gate = self._dot( + x, + self.w_gating[0], + None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]), + ) + gate_value = nn.gelu(ff_gate) + + ff1 = self._dot( + x, + self.w_gating[1], + None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]), + ) + activations = gate_value * ff1 + + outputs = self._dot(activations, self.w_linear, self.w_linear_lora) + assert outputs.dtype == dtype + return outputs + + def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array: + base = jnp.dot(x, w.astype(x.dtype)) + if lora_weights is None: + return base + return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype)) diff --git a/capvector-pi05/src/openpi/models/lora_test.py b/capvector-pi05/src/openpi/models/lora_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d303c025000565b8c3002cb9a42ac8246a2a4936 --- /dev/null +++ b/capvector-pi05/src/openpi/models/lora_test.py @@ -0,0 +1,94 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + +import openpi.models.lora as lora + + +def test_lora_einsum_params_shape(): + shape = (3, 8, 32, 4) # (3KDH) + einsum = lora.Einsum(shape) + lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2)) + lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2))) + + key = jax.random.key(0) + x = jax.random.normal(key, (8, 64, 32)) # (BSD) + eqn = "BSD,3KDH->3BSKH" + + # Ensure that lora parameters are not initialized when LoRA is not used. + params = einsum.init(key, eqn, x) + assert "lora_a" not in params["params"] + assert "lora_b" not in params["params"] + + # Check that default axes work. + params_lora0 = lora0.init(key, eqn, x) + assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2) + assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4) + + # Check that user provided axes work. + params_lora1 = lora1.init(key, eqn, x) + assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4) + assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4) + + +def test_lora_einsum_same_output(): + shape = (3, 8, 32, 4) # (3KDH) + einsum = lora.Einsum(shape) + einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros)) + + key = jax.random.key(0) + x = jax.random.normal(key, (8, 64, 32)) # (BSD) + eqn = "BSD,3KDH->3BSKH" + + params = einsum.init(key, eqn, x) + output = einsum.apply(params, eqn, x) + + params_lora = einsum_lora.init(key, eqn, x) + output_lora = einsum_lora.apply(params_lora, eqn, x) + + # Results are the same since the LoRA parameters are initialized to zeros. 
+    assert jnp.allclose(output, output_lora)
+
+
+def test_lora_ffn_params_shape():
+    ffn = lora.FeedForward(features=8, hidden_dim=32)
+    ffn_lora = lora.FeedForward(
+        features=8,
+        hidden_dim=32,
+        lora_config=lora.LoRAConfig(rank=2),
+    )
+
+    key = jax.random.key(0)
+    x = jax.random.normal(key, (2, 8))
+
+    params = ffn.init(key, x)
+    assert params["params"]["gating_einsum"].shape == (2, 8, 32)
+    assert params["params"]["linear"].shape == (32, 8)
+
+    params_lora = ffn_lora.init(key, x)
+    assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
+    assert params_lora["params"]["linear"].shape == (32, 8)
+    assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
+    assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
+    assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
+    assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
+
+
+def test_lora_ffn_same_output():
+    ffn = lora.FeedForward(features=8, hidden_dim=32)
+    ffn_lora = lora.FeedForward(
+        features=8,
+        hidden_dim=32,
+        lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
+    )
+
+    key = jax.random.key(0)
+    x = jax.random.normal(key, (2, 8))
+
+    params = ffn.init(key, x)
+    output = ffn.apply(params, x)
+
+    params_lora = ffn_lora.init(key, x)
+    output_lora = ffn_lora.apply(params_lora, x)
+
+    assert jnp.allclose(output, output_lora)
diff --git a/capvector-pi05/src/openpi/models/model.py b/capvector-pi05/src/openpi/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f097be4e3558084b017c7bceabc7d78a159d7da6
--- /dev/null
+++ b/capvector-pi05/src/openpi/models/model.py
@@ -0,0 +1,335 @@
+import abc
+from collections.abc import Sequence
+import dataclasses
+import enum
+import logging
+import pathlib
+from typing import Generic, TypeVar
+
+import augmax
+from flax import nnx
+from flax import struct
+from flax import traverse_util
+import jax
+import jax.numpy as jnp
+import numpy as np
+import orbax.checkpoint as ocp
+import safetensors
+import torch
+
+from openpi.models_pytorch import pi0_pytorch
+from openpi.shared import image_tools
+import openpi.shared.array_typing as at
+
+logger = logging.getLogger("openpi")
+
+# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
+ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
+
+
+class ModelType(enum.Enum):
+    """Supported model types."""
+
+    PI0 = "pi0"
+    PI0_FAST = "pi0_fast"
+    PI05 = "pi05"
+
+
+# The model always expects these images
+IMAGE_KEYS = (
+    "base_0_rgb",
+    "left_wrist_0_rgb",
+    "right_wrist_0_rgb",
+)
+
+
+# This may need to change if we release a small model.
+IMAGE_RESOLUTION = (224, 224)
+
+
+# Data format
+#
+# Data transforms produce the model input as a nested dictionary which is later converted
+# into `Observation` and `Actions` objects. See below.
+#
+# In the dictionary form, this data should look like:
+# {
+#     # Observation data.
+#     "image": {
+#         "base_0_rgb": (float32|uint8)[*b, h, w, 3],  # RGB image in [-1, 1] or [0, 255]
+#         ...  # Additional camera views
+#     },
+#     "image_mask": {
+#         "base_0_rgb": bool[*b],  # True if image is valid
+#         ...
# Masks for additional views +# }, +# "state": float32[*b, s], # Low-dimensional robot state +# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt +# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt +# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model +# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model +# +# # Actions data. +# "actions": float32[*b ah ad] +# } +# where: +# *b = batch dimensions +# h,w = image height/width +# s = state dimension +# l = sequence length +# +@at.typecheck +@struct.dataclass +class Observation(Generic[ArrayT]): + """Holds observations, i.e., inputs to the model. + + See `Observation.from_dict` to see the expected dictionary form. This is the format + that should be produced by the data transforms. + """ + + # Images, in [-1, 1] float32. + images: dict[str, at.Float[ArrayT, "*b h w c"]] + # the padding area for non-rectangular input images is False + image_padding_mask: dict[str, at.Bool[ArrayT, "*b w c"]] + # Image masks, with same keys as images. + image_masks: dict[str, at.Bool[ArrayT, "*b"]] + # Low-dimensional robot state. + state: at.Float[ArrayT, "*b s"] + + # Tokenized prompt. + tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None + # Tokenized prompt mask. + tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None + + # pi0-fast model specific fields. + + # Token auto-regressive mask (for FAST autoregressive model). + token_ar_mask: at.Int[ArrayT, "*b l"] | None = None + # Token loss mask (for FAST autoregressive model). + token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None + + @classmethod + def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]": + """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.""" + # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together. + if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data): + raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.") + # If images are uint8, convert them to [-1, 1] float32. + for key in data["image"]: + if data["image"][key].dtype == np.uint8: + data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0 + elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8: + data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0 + return cls( + images=data["image"], + image_padding_mask=data.get("image_padding_mask", {}), + image_masks=data["image_mask"], + state=data["state"], + tokenized_prompt=data.get("tokenized_prompt"), + tokenized_prompt_mask=data.get("tokenized_prompt_mask"), + token_ar_mask=data.get("token_ar_mask"), + token_loss_mask=data.get("token_loss_mask"), + ) + + def to_dict(self) -> at.PyTree[ArrayT]: + """Convert the Observation to a nested dict.""" + result = dataclasses.asdict(self) + result["image"] = result.pop("images") + result["image_mask"] = result.pop("image_masks") + return result + + +# Defines the format of the actions. This field is included as "actions" inside the dictionary +# produced by the data transforms. 
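Before the `Actions` alias that the comment above introduces, a minimal sketch of the dict-to-`Observation` path described in the data-format comment (a hypothetical single-view batch; the shapes are illustrative only):

```python
import numpy as np

from openpi.models import model as _model

batch = 2
data = {
    "image": {"base_0_rgb": np.zeros((batch, 224, 224, 3), dtype=np.uint8)},
    "image_mask": {"base_0_rgb": np.ones((batch,), dtype=bool)},
    "state": np.zeros((batch, 32), dtype=np.float32),
}
obs = _model.Observation.from_dict(data)
# from_dict rescales uint8 images to [-1, 1] float32:
assert obs.images["base_0_rgb"].dtype == np.float32
assert obs.images["base_0_rgb"].min() == -1.0
```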
+Actions = at.Float[ArrayT, "*b ah ad"] + + +def preprocess_observation( + rng: at.KeyArrayLike | None, + observation: Observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +) -> Observation: + """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and + filling in a default image mask (if necessary). + """ + + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for augmax. + image = image / 2.0 + 0.5 + + transforms = [] + if "wrist" not in key: + height, width = image.shape[1:3] + transforms += [ + augmax.RandomCrop(int(width * 0.95), int(height * 0.95)), + augmax.Resize(width, height), + augmax.Rotate((-5, 5)), + ] + transforms += [ + augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), + ] + sub_rngs = jax.random.split(rng, image.shape[0]) + image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image) + + # Back to [-1, 1]. + image = image * 2.0 - 1.0 + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool) + else: + out_masks[key] = jnp.asarray(observation.image_masks[key]) + + return Observation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) + + +@dataclasses.dataclass(frozen=True) +class BaseModelConfig(abc.ABC): + """Configuration shared by all models. Specific models should inherit from this class, and implement the `create` + method to create the corresponding model. + """ + + # Action space dimension. + action_dim: int + # Action sequence length. + action_horizon: int + # Tokenized prompt maximum length. 
+ max_token_len: int + + @property + @abc.abstractmethod + def model_type(self) -> ModelType: + """The model type.""" + + @abc.abstractmethod + def create(self, rng: at.KeyArrayLike) -> "BaseModel": + """Create a new model, initializing parameters.""" + + def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel": + """Create a model with the given parameters.""" + model = nnx.eval_shape(self.create, jax.random.key(0)) + graphdef, state = nnx.split(model) + if remove_extra_params: + params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params) + at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False) + state.replace_by_pure_dict(params) + return nnx.merge(graphdef, state) + + def load_pytorch(self, train_config, weight_path: str): + logger.info(f"train_config: {train_config}") + model = pi0_pytorch.PI0Pytorch(config=train_config.model) + safetensors.torch.load_model(model, weight_path) + return model + + @abc.abstractmethod + def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]: + """Returns the input specification for the model. Values are jax.ShapeDtypeStruct.""" + + def fake_obs(self, batch_size: int = 1) -> Observation: + observation_spec, _ = self.inputs_spec(batch_size=batch_size) + return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec) + + def fake_act(self, batch_size: int = 1) -> Actions: + _, action_spec = self.inputs_spec(batch_size=batch_size) + return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec) + + +@dataclasses.dataclass +class BaseModel(nnx.Module, abc.ABC): + """Base class for all model implementations. Specific models should inherit from this class. They should call + super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len). + """ + + action_dim: int + action_horizon: int + max_token_len: int + + @abc.abstractmethod + def compute_loss( + self, + rng: at.KeyArrayLike, + observation: Observation, + actions: Actions, + *, + train: bool = False, + ) -> at.Float[at.Array, "*b ah"]: ... + + @abc.abstractmethod + def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ... + + +def restore_params( + params_path: pathlib.Path | str, + *, + restore_type: type[np.ndarray] | type[jax.Array] = jax.Array, + dtype: jnp.dtype | None = None, + sharding: jax.sharding.Sharding | None = None, +) -> at.Params: + """Restores unstructured params PyTree from a checkpoint. + + This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as + well as pre-trained checkpoints released for openpi. + + Args: + params_path: The local path to the checkpoint directory. + restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array. + dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint. + sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices. + + Returns: + The restored params. 
+ """ + params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path + + if restore_type is jax.Array and sharding is None: + mesh = jax.sharding.Mesh(jax.devices(), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + with ocp.PyTreeCheckpointer() as ckptr: + metadata = ckptr.metadata(params_path) + item = {"params": metadata["params"]} + + params = ckptr.restore( + params_path, + ocp.args.PyTreeRestore( + item=item, + restore_args=jax.tree.map( + lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item + ), + ), + )["params"] + + # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is + # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict". + flat_params = traverse_util.flatten_dict(params) + if all(kp[-1] == "value" for kp in flat_params): + flat_params = {kp[:-1]: v for kp, v in flat_params.items()} + return traverse_util.unflatten_dict(flat_params) diff --git a/capvector-pi05/src/openpi/models/model_test.py b/capvector-pi05/src/openpi/models/model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..528dc32f9c0f42d5008f6496deec056141f8cb03 --- /dev/null +++ b/capvector-pi05/src/openpi/models/model_test.py @@ -0,0 +1,94 @@ +from flax import nnx +import jax +import pytest + +from openpi.models import model as _model +from openpi.models import pi0_config +from openpi.models import pi0_fast +from openpi.shared import download +from openpi.shared import nnx_utils + + +def test_pi0_model(): + key = jax.random.key(0) + config = pi0_config.Pi0Config() + model = config.create(key) + + batch_size = 2 + obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) + + loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) + assert loss.shape == (batch_size, config.action_horizon) + + actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) + assert actions.shape == (batch_size, model.action_horizon, model.action_dim) + + +def test_pi0_lora_model(): + key = jax.random.key(0) + config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora") + model = config.create(key) + + batch_size = 2 + obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) + + loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) + assert loss.shape == (batch_size, config.action_horizon) + + actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) + assert actions.shape == (batch_size, model.action_horizon, model.action_dim) + + +def test_pi0_fast_model(): + key = jax.random.key(0) + config = pi0_fast.Pi0FASTConfig() + model = config.create(key) + + batch_size = 2 + obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) + + loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) + assert loss.shape == (batch_size,) + + actions = nnx_utils.module_jit(model.sample_actions)(key, obs) + assert actions.shape == (batch_size, 256) + + +def test_pi0_fast_lora_model(): + key = jax.random.key(0) + config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora") + model = config.create(key) + + batch_size = 2 + obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) + + loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) + assert loss.shape == (batch_size,) + + actions = nnx_utils.module_jit(model.sample_actions)(key, obs) + assert actions.shape == (batch_size, 
256)
+
+    lora_filter = nnx_utils.PathRegex(".*lora.*")
+    model_state = nnx.state(model)
+
+    lora_state_elems = list(model_state.filter(lora_filter))
+    assert len(lora_state_elems) > 0
+
+
+@pytest.mark.manual
+def test_model_restore():
+    key = jax.random.key(0)
+    config = pi0_config.Pi0Config()
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    model = config.load(
+        _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
+    )
+
+    loss = model.compute_loss(key, obs, act)
+    assert loss.shape == (batch_size, config.action_horizon)
+
+    actions = model.sample_actions(key, obs, num_steps=10)
+    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
diff --git a/capvector-pi05/src/openpi/models/pi0.py b/capvector-pi05/src/openpi/models/pi0.py
new file mode 100644
index 0000000000000000000000000000000000000000..90fd7935a3f46c86b3100a42db701e8af719cbff
--- /dev/null
+++ b/capvector-pi05/src/openpi/models/pi0.py
@@ -0,0 +1,279 @@
+import logging
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+from openpi.models import pi0_config
+import openpi.models.gemma as _gemma
+import openpi.models.siglip as _siglip
+from openpi.shared import array_typing as at
+
+logger = logging.getLogger("openpi")
+
+
+def make_attn_mask(input_mask, mask_ar):
+    """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative mask_ar
+    smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+      [[1 1 1 1 1 1]]: pure causal attention.
+
+      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+          themselves and the last 3 tokens have a causal attention. The first
+          entry could also be a 1 without changing behaviour.
+
+      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+          block can attend all previous blocks and all tokens on the same block.
+
+    Args:
+      input_mask: bool[B, N] true if it is part of the input, false if padding.
+      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+        it and false where it shares the same attention mask as the previous token.
+ """ + mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape) + cumsum = jnp.cumsum(mask_ar, axis=1) + attn_mask = cumsum[:, None, :] <= cumsum[:, :, None] + valid_mask = input_mask[:, None, :] * input_mask[:, :, None] + return jnp.logical_and(attn_mask, valid_mask) + + +@at.typecheck +def posemb_sincos( + pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float +) -> at.Float[at.Array, "b {embedding_dim}"]: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if embedding_dim % 2 != 0: + raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2") + + fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2) + period = min_period * (max_period / min_period) ** fraction + sinusoid_input = jnp.einsum( + "i,j->ij", + pos, + 1.0 / period * 2 * jnp.pi, + precision=jax.lax.Precision.HIGHEST, + ) + return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1) + + +class Pi0(_model.BaseModel): + def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs): + super().__init__(config.action_dim, config.action_horizon, config.max_token_len) + self.pi05 = config.pi05 + paligemma_config = _gemma.get_config(config.paligemma_variant) + action_expert_config = _gemma.get_config(config.action_expert_variant) + # TODO: rewrite gemma in NNX. For now, use bridge. + llm = nnx_bridge.ToNNX( + _gemma.Module( + configs=[paligemma_config, action_expert_config], + embed_dtype=config.dtype, + adarms=config.pi05, + ) + ) + llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False]) + img = nnx_bridge.ToNNX( + _siglip.Module( + num_classes=paligemma_config.width, + variant="So400m/14", + pool_type="none", + scan=True, + dtype_mm=config.dtype, + ) + ) + img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) + self.PaliGemma = nnx.Dict(llm=llm, img=img) + self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs) + if config.pi05: + self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) + self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) + else: + self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs) + self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs) + self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) + self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs) + + # This attribute gets automatically set by model.train() and model.eval(). 
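The `make_attn_mask` helper defined above is easy to sanity-check numerically. A small sketch of its prefix-LM case on a hypothetical 6-token sequence (the `__init__` body continues below):

```python
import jax.numpy as jnp

from openpi.models.pi0 import make_attn_mask

# [[0 0 0 1 1 1]]: the first three tokens attend bidirectionally (prefix),
# the last three attend causally.
input_mask = jnp.ones((1, 6), dtype=bool)
mask_ar = jnp.array([[0, 0, 0, 1, 1, 1]], dtype=bool)
attn = make_attn_mask(input_mask, mask_ar)
assert attn.shape == (1, 6, 6)
assert attn[0, 1, 2]      # prefix token 1 sees "future" prefix token 2
assert not attn[0, 3, 4]  # suffix token 3 cannot see future token 4
assert attn[0, 4, 3]      # but token 4 sees past token 3
```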
+ self.deterministic = True + + @at.typecheck + def embed_prefix( + self, obs: _model.Observation + ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]: + input_mask = [] + ar_mask = [] + tokens = [] + # embed images + for name in obs.images: + image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False) + + tokens.append(image_tokens) + input_mask.append( + einops.repeat( + obs.image_masks[name], + "b -> b s", + s=image_tokens.shape[1], + ) + ) + # image tokens attend to each other + ar_mask += [False] * image_tokens.shape[1] + + # add language (aka tokenized inputs) + if obs.tokenized_prompt is not None: + tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed") + tokens.append(tokenized_inputs) + input_mask.append(obs.tokenized_prompt_mask) + # full attention between image and language inputs + ar_mask += [False] * tokenized_inputs.shape[1] + tokens = jnp.concatenate(tokens, axis=1) + input_mask = jnp.concatenate(input_mask, axis=1) + ar_mask = jnp.array(ar_mask) + return tokens, input_mask, ar_mask + + @at.typecheck + def embed_suffix( + self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"] + ) -> tuple[ + at.Float[at.Array, "b s emb"], + at.Bool[at.Array, "b s"], + at.Bool[at.Array, " s"], + at.Float[at.Array, "b emb"] | None, + ]: + input_mask = [] + ar_mask = [] + tokens = [] + if not self.pi05: + # add a single state token + state_token = self.state_proj(obs.state)[:, None, :] + tokens.append(state_token) + input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_)) + # image/language inputs do not attend to state or actions + ar_mask += [True] + + action_tokens = self.action_in_proj(noisy_actions) + # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0) + if self.pi05: + # time MLP (for adaRMS) + time_emb = self.time_mlp_in(time_emb) + time_emb = nnx.swish(time_emb) + time_emb = self.time_mlp_out(time_emb) + time_emb = nnx.swish(time_emb) + action_expert_tokens = action_tokens + adarms_cond = time_emb + else: + # mix timestep + action information using an MLP (no adaRMS) + time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon) + action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1) + action_time_tokens = self.action_time_mlp_in(action_time_tokens) + action_time_tokens = nnx.swish(action_time_tokens) + action_time_tokens = self.action_time_mlp_out(action_time_tokens) + action_expert_tokens = action_time_tokens + adarms_cond = None + tokens.append(action_expert_tokens) + input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_)) + # image/language/state inputs do not attend to action tokens + ar_mask += [True] + ([False] * (self.action_horizon - 1)) + tokens = jnp.concatenate(tokens, axis=1) + input_mask = jnp.concatenate(input_mask, axis=1) + ar_mask = jnp.array(ar_mask) + return tokens, input_mask, ar_mask, adarms_cond + + @override + def compute_loss( + self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False + ) -> at.Float[at.Array, "*b ah"]: + preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) + observation = _model.preprocess_observation(preprocess_rng, observation, train=train) + + batch_shape = actions.shape[:-2] + noise = jax.random.normal(noise_rng, actions.shape) + time = 
jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 + time_expanded = time[..., None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + # one big forward pass of prefix + suffix at once + prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) + suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time) + input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1) + ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0) + attn_mask = make_attn_mask(input_mask, ar_mask) + positions = jnp.cumsum(input_mask, axis=1) - 1 + (prefix_out, suffix_out), _ = self.PaliGemma.llm( + [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond] + ) + v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) + + return jnp.mean(jnp.square(v_t - u_t), axis=-1) + + @override + def sample_actions( + self, + rng: at.KeyArrayLike, + observation: _model.Observation, + *, + num_steps: int | at.Int[at.Array, ""] = 10, + noise: at.Float[at.Array, "b ah ad"] | None = None, + ) -> _model.Actions: + observation = _model.preprocess_observation(None, observation, train=False) + # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target + # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry. + dt = -1.0 / num_steps + batch_size = observation.state.shape[0] + if noise is None: + noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim)) + + # first fill KV cache with a forward pass of the prefix + prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) + prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) + positions = jnp.cumsum(prefix_mask, axis=1) - 1 + _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions) + + def step(carry): + x_t, time = carry + suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix( + observation, x_t, jnp.broadcast_to(time, batch_size) + ) + # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each + # other + suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask) + # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the + # prefix tokens + prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1]) + # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which + # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values) + full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1) + assert full_attn_mask.shape == ( + batch_size, + suffix_tokens.shape[1], + prefix_tokens.shape[1] + suffix_tokens.shape[1], + ) + # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens + positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1 + + (prefix_out, suffix_out), _ = self.PaliGemma.llm( + [None, suffix_tokens], + mask=full_attn_mask, + positions=positions, + kv_cache=kv_cache, + adarms_cond=[None, adarms_cond], + ) + assert prefix_out is None + v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) + + return x_t + dt * v_t, time + dt + + def cond(carry): + x_t, time = carry + # robust to floating-point 
error + return time >= -dt / 2 + + x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0)) + return x_0 diff --git a/capvector-pi05/src/openpi/models/pi0_config.py b/capvector-pi05/src/openpi/models/pi0_config.py new file mode 100644 index 0000000000000000000000000000000000000000..26c97f720d491c917fdb870ff8a85102e3c7a3b5 --- /dev/null +++ b/capvector-pi05/src/openpi/models/pi0_config.py @@ -0,0 +1,108 @@ +import dataclasses +from typing import TYPE_CHECKING + +import flax.nnx as nnx +import jax +import jax.numpy as jnp +from typing_extensions import override + +from openpi.models import model as _model +import openpi.models.gemma as _gemma +from openpi.shared import array_typing as at +import openpi.shared.nnx_utils as nnx_utils + +if TYPE_CHECKING: + from openpi.models.pi0 import Pi0 + + +@dataclasses.dataclass(frozen=True) +class Pi0Config(_model.BaseModelConfig): + dtype: str = "bfloat16" + paligemma_variant: _gemma.Variant = "gemma_2b" + action_expert_variant: _gemma.Variant = "gemma_300m" + + # Set the model specific defaults. + action_dim: int = 32 + action_horizon: int = 50 + max_token_len: int = None # type: ignore + # Pi05 has two differences from Pi0: + # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix + # - the action expert uses adaRMSNorm to inject the flow matching timestep + pi05: bool = False + # This config option is not used directly by the model, but it is read by the ModelTransformFactory. + discrete_state_input: bool = None # type: ignore + + def __post_init__(self): + if self.max_token_len is None: + object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48) + if self.discrete_state_input is None: + object.__setattr__(self, "discrete_state_input", self.pi05) + + @property + @override + def model_type(self) -> _model.ModelType: + if self.pi05: + return _model.ModelType.PI05 + return _model.ModelType.PI0 + + @override + def create(self, rng: at.KeyArrayLike) -> "Pi0": + from openpi.models.pi0 import Pi0 + + return Pi0(self, rngs=nnx.Rngs(rng)) + + @override + def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]: + image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32) + image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_) + + with at.disable_typechecking(): + observation_spec = _model.Observation( + images={ + "base_0_rgb": image_spec, + "left_wrist_0_rgb": image_spec, + "right_wrist_0_rgb": image_spec, + }, + image_masks={ + "base_0_rgb": image_mask_spec, + "left_wrist_0_rgb": image_mask_spec, + "right_wrist_0_rgb": image_mask_spec, + }, + state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), + tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), + tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool), + ) + action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32) + + return observation_spec, action_spec + + def get_freeze_filter(self) -> nnx.filterlib.Filter: + """Returns the freeze filter based on the model config.""" + filters = [] + has_lora = False + gemma_params_filter = nnx_utils.PathRegex(".*llm.*") + action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*") + if "lora" in self.paligemma_variant: + filters.append( + gemma_params_filter, + ) + if "lora" not in self.action_expert_variant: + # If only freeze gemma params, exclude action expert params. 
+                filters.append(
+                    nnx.Not(action_expert_params_filter),
+                )
+            has_lora = True
+        elif "lora" in self.action_expert_variant:
+            filters.append(
+                action_expert_params_filter,
+            )
+            has_lora = True
+
+        if has_lora:
+            # If any lora is used, exclude all lora params.
+            filters.append(
+                nnx.Not(nnx_utils.PathRegex(".*lora.*")),
+            )
+        if not filters:
+            return nnx.Nothing
+        return nnx.All(*filters)
diff --git a/capvector-pi05/src/openpi/models/pi0_fast.py b/capvector-pi05/src/openpi/models/pi0_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c3ed5503c27facc30a293d4740195ebc8b5fe46
--- /dev/null
+++ b/capvector-pi05/src/openpi/models/pi0_fast.py
@@ -0,0 +1,313 @@
+import dataclasses
+import logging
+from typing import Any
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+import openpi.models.gemma_fast as _gemma
+import openpi.models.siglip as _siglip
+from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+
+logger = logging.getLogger("openpi")
+
+PALIGEMMA_EOS_TOKEN = 1
+
+
+def make_attn_mask(input_mask, mask_ar):
+    """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative mask_ar
+    smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+      [[1 1 1 1 1 1]]: pure causal attention.
+
+      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+          themselves and the last 3 tokens have a causal attention. The first
+          entry could also be a 1 without changing behaviour.
+
+      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+          block can attend all previous blocks and all tokens on the same block.
+
+    Args:
+      input_mask: bool[B, N] true if it is part of the input, false if padding.
+      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+        it and false where it shares the same attention mask as the previous token.
+    """
+    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+    cumsum = jnp.cumsum(mask_ar, axis=1)
+    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+    return jnp.logical_and(attn_mask, valid_mask)
+
+
+@jax.vmap
+def left_to_right_align(x, input_mask, attn_mask):
+    """Converts input from left-aligned to right-aligned."""
+    # Due to vmap, this is operating on a single example (not batch level).
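Before the per-example assertions below, a right-alignment sketch with illustrative shapes: two rows with valid lengths 2 and 3 are rolled so their valid tokens end flush at the right edge, which is what makes the fixed-size KV-cache prefill work.

```python
import jax.numpy as jnp

from openpi.models.pi0_fast import left_to_right_align

x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4, 1)  # [B=2, N=4, D=1]
input_mask = jnp.array([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=bool)
attn_mask = jnp.ones((2, 4, 4), dtype=bool)

_, aligned_mask, _ = left_to_right_align(x, input_mask, attn_mask)
# Valid positions are now right-aligned:
# [[False False  True  True]
#  [False  True  True  True]]
assert (aligned_mask.sum(-1) == input_mask.sum(-1)).all()
```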
+ assert x.ndim == 2 + assert input_mask.ndim == 1 + assert attn_mask.ndim == 2 + assert x.shape[0] == input_mask.shape[0] + assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape + seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1 + x = jnp.roll(x, -seqlen, axis=0) + input_mask = jnp.roll(input_mask, -seqlen, axis=0) + attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1)) + return x, input_mask, attn_mask + + +def put_along_last_axis(arr, indices, values): + """Like np.put_along_axis(..., axis=-1), since jax is missing it.""" + assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim) + onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype) + put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot) + put_values = jnp.einsum("...i,...in->...n", values, onehot) + return jnp.where(put_mask, put_values, arr) + + +@dataclasses.dataclass(frozen=True) +class Pi0FASTConfig(_model.BaseModelConfig): + dtype: str = "bfloat16" + paligemma_variant: _gemma.Variant = "gemma_2b" + + # Set the model specific defaults. + action_dim: int = 32 + action_horizon: int = 32 + max_token_len: int = 250 + + # Tokenizer for the fast model. + fast_model_tokenizer: Any | None = None + # Keyword arguments for the fast model tokenizer. + fast_model_tokenizer_kwargs: dict[str, Any] | None = None + + @property + @override + def model_type(self) -> _model.ModelType: + return _model.ModelType.PI0_FAST + + @override + def create(self, rng: at.KeyArrayLike) -> "Pi0FAST": + return Pi0FAST(self, rngs=nnx.Rngs(rng)) + + @override + def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]: + image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32) + image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_) + + with at.disable_typechecking(): + observation_spec = _model.Observation( + images={ + "base_0_rgb": image_spec, + "base_1_rgb": image_spec, + "wrist_0_rgb": image_spec, + }, + image_masks={ + "base_0_rgb": image_mask_spec, + "base_1_rgb": image_mask_spec, + "wrist_0_rgb": image_mask_spec, + }, + state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), + tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), + tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool), + token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), + token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_), + ) + action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32) + + return observation_spec, action_spec + + def get_freeze_filter(self) -> nnx.filterlib.Filter: + """Returns the freeze filter based on the model config.""" + if "lora" in self.paligemma_variant: + return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*"))) + return nnx.Nothing + + +class Pi0FAST(_model.BaseModel): + def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs): + super().__init__(config.action_dim, config.action_horizon, config.max_token_len) + paligemma_config = _gemma.get_config(config.paligemma_variant) + # TODO: rewrite gemma in NNX. For now, use bridge. 
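A quick sketch of the scatter helper defined above (the values are arbitrary): it writes `values` at `indices` along the last axis via one-hot masks, which keeps the operation batchable and XLA-friendly. The bridge wrapping of the LLM continues below.

```python
import jax.numpy as jnp

from openpi.models.pi0_fast import put_along_last_axis

arr = jnp.zeros((2, 4))
indices = jnp.array([[2], [0]])
values = jnp.array([[7.0], [5.0]])
out = put_along_last_axis(arr, indices, values)
# out == [[0. 0. 7. 0.]
#         [5. 0. 0. 0.]]
assert out[0, 2] == 7.0
assert out[1, 0] == 5.0
```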
+ llm = nnx_bridge.ToNNX( + _gemma.Module( + **paligemma_config, + embed_dtype=config.dtype, + cache_dtype=config.dtype, + ) + ) + llm.lazy_init(rngs=rngs, method="init") + img = nnx_bridge.ToNNX( + _siglip.Module( + num_classes=paligemma_config.width, + variant="So400m/14", + pool_type="none", + scan=True, + dtype_mm=config.dtype, + ) + ) + img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) + self.PaliGemma = nnx.Dict(llm=llm, img=img) + + @at.typecheck + def embed_inputs( + self, obs: _model.Observation + ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]: + input_mask = [] + ar_mask = [] + token_embeddings = [] + # embed images + for name in obs.images: + image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False) + + token_embeddings.append(image_token_embeddings) + input_mask.append( + einops.repeat( + obs.image_masks[name], + "b -> b s", + s=image_token_embeddings.shape[1], + ) + ) + # image tokens attend to each other --> AR mask = 0 + ar_mask.append(0 * input_mask[-1]) + + # add tokenized inputs + assert obs.tokenized_prompt is not None, "Tokenized prompt is required" + assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required" + assert obs.token_ar_mask is not None, "Token auto-regressive mask is required" + tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True) + token_embeddings.append(tokenized_inputs_embeddings) + input_mask.append(obs.tokenized_prompt_mask) + ar_mask.append(obs.token_ar_mask) + + # return embeddings, input mask, and ar mask + return ( + jnp.concatenate(token_embeddings, axis=1), + jnp.concatenate(input_mask, axis=1), + jnp.concatenate(ar_mask, axis=1), + ) + + @override + def compute_loss( + self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False + ) -> at.Float[at.Array, "*b ah"]: + observation = _model.preprocess_observation( + rng, observation, train=train, image_keys=list(observation.images.keys()) + ) + + # Compute inputs: one big forward pass of prefix + suffix at once + input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation) + attn_mask = make_attn_mask(input_mask, ar_mask) + + # Compute one-hot targets: we predict *next* token, so shift the input tokens by one. + targets = jax.nn.one_hot( + observation.tokenized_prompt[:, 1:], + self.PaliGemma.llm.module.vocab_size, + ) + + # Each input predicts *next* token, so we don't input the last token. + pre_logits, _, _ = self.PaliGemma.llm( + embedded_prefix=input_token_embeddings[:, :-1], + mask=attn_mask[:, :-1, :-1], + return_prelogits=True, + ) + + # Only decode logits for the target tokens to save memory + # (decoding matmul is large because it is a seq_len x vocab_size dense layer). + logits, _ = self.PaliGemma.llm( + pre_logits=pre_logits[:, -targets.shape[1] :], + ) + logp = jax.nn.log_softmax(logits, axis=-1) + + # Compute CE loss on token targets + assert observation.token_loss_mask is not None, "Token loss mask is required" + loss_mask = observation.token_loss_mask[:, 1:] + token_pplx = jnp.sum(targets * logp, axis=-1) + return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1) + + @override + def sample_actions( + self, + rng: at.KeyArrayLike, + observation: _model.Observation, + *, + max_decoding_steps: int | at.Int[at.Array, ""] = 256, + temperature: float = 0.0, + ) -> _model.Actions: + # TODO: this is a hack to get the image keys. 
+ observation = _model.preprocess_observation( + None, observation, train=False, image_keys=list(observation.images.keys()) + ) + + # embed inputs + prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation) + prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) + + # left to right align all input token sequences + prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align( + prefix_token_embeddings, prefix_mask, prefix_attn_mask + ) + prefill_size = prefix_token_embeddings.shape[1] + prefill_len = jnp.sum(prefix_mask, axis=-1) + prefix_start = prefill_size - prefill_len + + # first fill KV cache with a forward pass of the prefix + # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps) + prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps))) + prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1 + prefix_logits, kv_cache, _ = self.PaliGemma.llm( + embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True + ) + + # prepare decoding -- final logit decodes the first token + last_logit = prefix_logits[:, -1:] + output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps)) + + def step(carry): + rng, last_logit, output_tokens, cache, _, step = carry + + # Sample token from last logit + # Split RNG for this step + rng, rng_step = jax.random.split(rng) + token = jax.lax.cond( + temperature > 0.0, + lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1), + lambda _: jnp.argmax(last_logit, axis=-1), + operand=None, + ) + output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token) + + # Check for early stopping --> stop if all batch elements have EOS token + has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1) + all_eos = jnp.all(has_eos) + + # Decode one step + token_embedding = self.PaliGemma.llm(token, embed_only=True) + positions = prefill_len[:, None] + step + 1 + mask = jnp.logical_and( + jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None], + jnp.arange(prefill_size + max_decoding_steps)[None, None, :] + < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))), + ) + last_logit, kv_cache, _ = self.PaliGemma.llm( + embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache + ) + + return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1 + + def cond(carry): + _, _, _, _, all_eos, step = carry + return (~all_eos) & (step < max_decoding_steps) + + # Use lax.while_loop so we can jit the full decoding loop. 
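The reason for `lax.while_loop` here is easy to miss: a Python `while` over traced values cannot be jitted, but `while_loop` compiles data-dependent early stopping. A toy analogue of the pattern (the EOS-at-5 condition is made up):

```python
import jax
import jax.numpy as jnp

@jax.jit
def decode(start):
    def cond(carry):
        token, step = carry
        return (token != 5) & (step < 10)  # stop on "EOS" or at the step limit

    def body(carry):
        token, step = carry
        return token + 1, step + 1  # stand-in for one real decoding step

    _, steps = jax.lax.while_loop(cond, body, (start, 0))
    return steps

assert decode(jnp.int32(2)) == 3  # 2 -> 3 -> 4 -> 5
```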
+ _, _, output_tokens, _, _, _ = jax.lax.while_loop( + cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0) + ) + return output_tokens diff --git a/capvector-pi05/src/openpi/models/pi0_test.py b/capvector-pi05/src/openpi/models/pi0_test.py new file mode 100644 index 0000000000000000000000000000000000000000..793739d137c88665a66f9396513f7264654e8cb1 --- /dev/null +++ b/capvector-pi05/src/openpi/models/pi0_test.py @@ -0,0 +1,46 @@ +import flax.nnx as nnx +import jax + +import openpi.models.pi0_config as _pi0_config + + +def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State: + abstract_model = nnx.eval_shape(config.create, jax.random.key(0)) + + freeze_filter = config.get_freeze_filter() + return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state() + + +def test_pi0_full_finetune(): + config = _pi0_config.Pi0Config() + state = _get_frozen_state(config) + assert len(state) == 0 + + +def test_pi0_gemma_lora(): + config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora") + state = _get_frozen_state(config) + assert len(state) == 9 + assert all("lora" not in p for p in state) + assert all("llm" in p for p in state) + assert all("_1" not in p for p in state) + + +def test_pi0_action_expert_lora(): + config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora") + state = _get_frozen_state(config) + # excluding embedder, rest of the params should be same as gemma_lora. + assert len(state) == 8 + assert all("lora" not in p for p in state) + assert all("llm" in p for p in state) + # all frozen params should have _1 in their path since it's the action expert. + assert all(any("_1" in p for p in path) for path in state) + + +def test_pi0_all_lora(): + config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora") + state = _get_frozen_state(config) + # sum of gemma_lora and action_expert_lora's frozen params. + assert len(state) == 17 + assert all("lora" not in p for p in state) + assert all("llm" in p for p in state) diff --git a/capvector-pi05/src/openpi/models/siglip.py b/capvector-pi05/src/openpi/models/siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..e306802ac98a31174f7922a98018e13a1af58647 --- /dev/null +++ b/capvector-pi05/src/openpi/models/siglip.py @@ -0,0 +1,373 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""A refactored and simplified ViT adoptation for Pi, taken from big_vision.""" + +from collections.abc import Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np + +import openpi.training.sharding as sharding + + +def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32): + """Follows the MoCo v3 logic.""" + y, x = jnp.mgrid[:h, :w] + + assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" + omega = jnp.arange(width // 4) / (width // 4 - 1) + omega = 1.0 / (temperature**omega) + y = jnp.einsum("m,d->md", y.flatten(), omega) + x = jnp.einsum("m,d->md", x.flatten(), omega) + pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) + return jnp.asarray(pe, dtype)[None, :, :] + + +def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32): + if typ == "learn": + return self.param( + name, + nn.initializers.normal(stddev=1 / np.sqrt(width)), + (1, np.prod(seqshape), width), + dtype, + ) + if typ == "sincos2d": + return posemb_sincos_2d(*seqshape, width, dtype=dtype) + raise ValueError(f"Unknown posemb type: {typ}") + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + + mlp_dim: int | None = None # Defaults to 4x input dim + dropout: float = 0.0 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): # noqa: FBT002 + """Applies Transformer MlpBlock module.""" + inits = { + "kernel_init": nn.initializers.xavier_uniform(), + "bias_init": nn.initializers.normal(stddev=1e-6), + } + + _, _, d = x.shape # n,l,d + x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout)(x, deterministic) + return nn.Dense(d, dtype=self.dtype_mm, **inits)(x) + + +class Encoder1DBlock(nn.Module): + """Single transformer encoder block (MHSA + MLP).""" + + mlp_dim: int | None = None # Defaults to 4x input dim + num_heads: int = 12 + dropout: float = 0.0 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): # noqa: FBT002 + out = {} + x = sharding.activation_sharding_constraint(x) + y = nn.LayerNorm(dtype=self.dtype_mm)(x) + y = out["sa"] = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=deterministic, + dtype=self.dtype_mm, + )(y, y) + y = sharding.activation_sharding_constraint(y) + y = nn.Dropout(rate=self.dropout)(y, deterministic) + x = out["+sa"] = x + y + + y = nn.LayerNorm(dtype=self.dtype_mm)(x) + y = out["mlp"] = MlpBlock( + mlp_dim=self.mlp_dim, + dropout=self.dropout, + dtype_mm=self.dtype_mm, + )(y, deterministic) + y = sharding.activation_sharding_constraint(y) + y = nn.Dropout(rate=self.dropout)(y, deterministic) + x = out["+mlp"] = x + y + x = sharding.activation_sharding_constraint(x) + return x, out + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + depth: int + mlp_dim: int | None = None # Defaults to 4x input dim + num_heads: int = 12 + dropout: float = 0.0 + scan: bool = False + remat_policy: str = "nothing_saveable" + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x, deterministic=True): # noqa: FBT002 + out = {} + + if self.scan: + block = nn.remat( + Encoder1DBlock, + prevent_cse=False, + static_argnums=(2,), # 0=self, 2=deterministic + policy=getattr(jax.checkpoint_policies, self.remat_policy, None), + ) + x, scan_out = nn.scan( + block, + variable_axes={"params": 0}, + split_rngs={"params": 
True, "dropout": True}, + in_axes=nn.broadcast, + length=self.depth, + )( + name="encoderblock", + dtype_mm=self.dtype_mm, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout, + )(x, deterministic) + for lyr in range(self.depth): + out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out) + else: + # Input Encoder + for lyr in range(self.depth): + block_cur = Encoder1DBlock( + name=f"encoderblock_{lyr}", + dtype_mm=self.dtype_mm, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout, + ) + x, out[f"block{lyr:02d}"] = block_cur(x, deterministic) + out["pre_ln"] = x # Alias for last block, but without the number in it. + + return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + mlp_dim: int | None = None # Defaults to 4x input dim + num_heads: int = 12 + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, x): + n, _, d = x.shape # n,l,d + probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + dtype=self.dtype_mm, + kernel_init=nn.initializers.xavier_uniform(), + )(probe, x) + + y = nn.LayerNorm(dtype=self.dtype_mm)(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y) + return x[:, 0] + + +class _Module(nn.Module): + """ViT model.""" + + num_classes: int | None = None + patch_size: Sequence[int] = (16, 16) + width: int = 768 + depth: int = 12 + mlp_dim: int | None = None # Defaults to 4x input dim + num_heads: int = 12 + posemb: str = "learn" # Can also be "sincos2d" + rep_size: int | bool = False + dropout: float = 0.0 + pool_type: str = "gap" # Can also be "map" or "tok" + head_zeroinit: bool = True + scan: bool = False + # or "dots_with_no_batch_dims_saveable" for more speed (memory costly) + remat_policy: str = "nothing_saveable" + dtype_mm: str = "float32" + + @nn.compact + def __call__(self, image, *, train=False): + out = {} + + # Kevin edit: do patch extraction and posemb in float32, + # because I feel like it's a bit safer. + image = jnp.asarray(image, jnp.float32) + + # Patch extraction + x = out["stem"] = nn.Conv( + self.width, + self.patch_size, + strides=self.patch_size, + padding="VALID", + name="embedding", + dtype=jnp.float32, + )(image) + + n, h, w, c = x.shape + x = jnp.reshape(x, [n, h * w, c]) + + # Add posemb before adding extra token. 
+ x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32) + + if self.pool_type == "tok": + cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype) + x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1) + + n, _, c = x.shape # n,l,d + x = nn.Dropout(rate=self.dropout)(x, not train) + + # Kevin edit: now cast back to dtype_mm (potentially half precision) + x = x.astype(self.dtype_mm) + + x, out["encoder"] = Encoder( + depth=self.depth, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + dropout=self.dropout, + scan=self.scan, + remat_policy=self.remat_policy, + dtype_mm=self.dtype_mm, + name="Transformer", + )(x, deterministic=not train) + encoded = out["encoded"] = x + + if self.pool_type == "map": + x = out["head_input"] = MAPHead( + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dtype=self.dtype_mm, + )(x) + elif self.pool_type == "gap": + x = out["head_input"] = jnp.mean(x, axis=1) + elif self.pool_type == "0": + x = out["head_input"] = x[:, 0] + elif self.pool_type == "tok": + x = out["head_input"] = x[:, 0] + encoded = encoded[:, 1:] + elif self.pool_type == "none": + pass + else: + raise ValueError(f"Unknown pool type: '{self.pool_type}'") + + x_2d = jnp.reshape(encoded, [n, h, w, -1]) + + if self.rep_size: + rep_size = self.width if self.rep_size is True else self.rep_size + hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits") + # NOTE: In the past we did not include tanh in pre_logits. + # For few-shot, it should not matter much, as it whitens anyways. + x_2d = nn.tanh(hid(x_2d)) + x = nn.tanh(hid(x)) + + out["pre_logits_2d"] = x_2d + out["pre_logits"] = x + + if self.num_classes: + kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} + head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw) + x_2d = out["logits_2d"] = head(x_2d) + x = out["logits"] = head(x) + + return x, out + + +def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802 + """Factory function, because linen really don't like what I'm doing!""" + return _Module(num_classes, **{**decode_variant(variant), **kw}) + + +def decode_variant(variant): + """Converts a string like "B" or "B/32" into a params dict.""" + if variant is None: + return {} + + v, patch = variant, {} + if "/" in variant: + v, patch = variant.split("/") + patch = {"patch_size": (int(patch), int(patch))} + + return { + # pylint:disable=line-too-long + # Reference: Table 2 of https://arxiv.org/abs/2106.04560. 
+ "width": { + "mu": 32, + "Ti": 192, + "S": 384, + "M": 512, + "B": 768, + "L": 1024, + "So400m": 1152, + "H": 1280, + "g": 1408, + "g-opt": 1536, + "G": 1664, + "G-opt": 1536, + "e": 1792, + }[v], + "depth": { + "mu": 1, + "Ti": 12, + "S": 12, + "M": 12, + "B": 12, + "L": 24, + "So400m": 27, + "H": 32, + "g": 40, + "g-opt": 40, + "G": 48, + "G-opt": 48, + "e": 56, + }[v], + "mlp_dim": { + "mu": 128, + "Ti": 768, + "S": 1536, + "M": 2048, + "B": 3072, + "L": 4096, + "So400m": 4304, + "H": 5120, + "g": 6144, + "g-opt": 6144, + "G": 8192, + "G-opt": 8192, + "e": 15360, + }[v], + "num_heads": { + "mu": 2, + "Ti": 3, + "S": 6, + "M": 8, + "B": 12, + "L": 16, + "So400m": 16, + "H": 16, + "g": 16, + "g-opt": 16, + "G": 16, + "G-opt": 16, + "e": 16, + }[v], + # pylint:enable=line-too-long + **patch, + } diff --git a/capvector-pi05/src/openpi/models/tokenizer.py b/capvector-pi05/src/openpi/models/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ec36ff6f83524e4eadec2c4474f92ee1b5f2bb20 --- /dev/null +++ b/capvector-pi05/src/openpi/models/tokenizer.py @@ -0,0 +1,371 @@ +import logging +import os + +import jax +import numpy as np +import orbax.checkpoint as ocp +import sentencepiece +from transformers import AutoProcessor + +import openpi.models.utils.fsq_tokenizer as fsq_tokenizer +import openpi.shared.download as download + + +class PaligemmaTokenizer: + def __init__(self, max_len: int = 48): + self._max_len = max_len + + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) + with path.open("rb") as f: + self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) + + def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]: + cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ") + if state is not None: + # This is the Pi05 format, where the state is part of the discrete language input. + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + state_str = " ".join(map(str, discretized_state)) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + tokens = self._tokenizer.encode(full_prompt, add_bos=True) + else: + # This is the Pi0 format, where the state is part of the continuous action expert input. + # tokenize "\n" separately as the "start of answer" token + tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n") + tokens_len = len(tokens) + if tokens_len < self._max_len: + padding = [False] * (self._max_len - tokens_len) + mask = [True] * tokens_len + padding + tokens = tokens + padding + else: + if len(tokens) > self._max_len: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " + "Consider increasing the `max_token_len` in your model config if this happens frequently." 
+ ) + tokens = tokens[: self._max_len] + mask = [True] * self._max_len + + return np.asarray(tokens), np.asarray(mask) + + +class FASTTokenizer: + def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"): + self._max_len = max_len + + # Download base PaliGemma tokenizer + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) + with path.open("rb") as f: + self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) + + # Instantiate FAST tokenizer + self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) + self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens + + def tokenize( + self, prompt: str, state: np.ndarray, actions: np.ndarray | None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + cleaned_text = prompt.lower().strip().replace("_", " ") + + # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + # Convention: prefix includes prompt and string-representation of state, followed by ';' + state_str = " ".join(map(str, discretized_state)) + prefix = f"Task: {cleaned_text}, State: {state_str};\n" + prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) + + if actions is not None: + # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab + action_tokens = self._fast_tokenizer(actions[None])[0] + action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens) + + # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|' + postfix_tokens = ( + self._paligemma_tokenizer.encode("Action: ") + + action_tokens_in_pg.tolist() + + self._paligemma_tokenizer.encode("|", add_eos=True) + ) + else: + postfix_tokens = [] + + # Create output token sequence & masks + # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) + tokens = prefix_tokens + postfix_tokens + token_mask = [True] * len(tokens) + ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) + loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only + + # Pad tokens to max length + tokens_len = len(tokens) + if tokens_len < self._max_len: + padding = [False] * (self._max_len - tokens_len) + tokens = tokens + padding + token_mask = token_mask + padding + ar_mask = ar_mask + padding + loss_mask = loss_mask + padding + else: + if len(tokens) > self._max_len: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " + "Consider increasing the `max_token_len` in your model config if this happens frequently." 
+ ) + tokens = tokens[: self._max_len] + token_mask = token_mask[: self._max_len] + ar_mask = ar_mask[: self._max_len] + loss_mask = loss_mask[: self._max_len] + + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) + + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: + # Decode predicted output tokens + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) + + # Extract actions from FAST model outputs + if "Action: " not in decoded_tokens: + return np.zeros((action_horizon, action_dim), dtype=np.float32) + + # Extract actions from decoded tokens + raw_action_tokens = np.array( + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) + ) + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) + return self._fast_tokenizer.decode( + [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim + )[0] + + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: + if isinstance(tokens, list): + tokens = np.array(tokens) + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens + + +########################################################################### +## The tokenizers below are used for RoboArena baseline implementations. ## +## They are *not* used for pi0-style models. ## +########################################################################### + + +class BinningTokenizer: + """ + Standard RT-2 / OpenVLA style binning tokenizer. + """ + + def __init__(self, max_len: int = 256, n_bins: int = 256): + self._max_len = max_len + self._n_bins = n_bins + + # Download base PaliGemma tokenizer + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) + with path.open("rb") as f: + self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) + + self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens + + def tokenize( + self, prompt: str, state: np.ndarray, actions: np.ndarray | None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Tokenize a prompt and state into a sequence of tokens. + + Args: + prompt: The text prompt to tokenize. + state: The state array to discretize and tokenize. + actions: Must be None. Action encoding is not currently supported. + + Returns: + A tuple of (tokens, token_mask, ar_mask, targets). + + Raises: + NotImplementedError: If actions is not None. 
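+
+ Note: `targets` here is the per-token loss mask (True on postfix positions
+ only); prompt and state tokens never receive loss.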
+ """ + cleaned_text = prompt.lower().strip().replace("_", " ") + + # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + # Convention: prefix includes prompt and string-representation of state, followed by ';' + state_str = " ".join(map(str, discretized_state)) + prefix = f"Task: {cleaned_text}, State: {state_str};\n" + prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) + + if actions is not None: + raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)") + postfix_tokens = [] + + # Create output token sequence & masks + # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) + tokens = prefix_tokens + postfix_tokens + token_mask = [True] * len(tokens) + ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) + loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only + + # Pad tokens to max length + tokens_len = len(tokens) + if tokens_len < self._max_len: + padding = [False] * (self._max_len - tokens_len) + tokens = tokens + padding + token_mask = token_mask + padding + ar_mask = ar_mask + padding + loss_mask = loss_mask + padding + else: + if len(tokens) > self._max_len: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " + "Consider increasing the `max_token_len` in your model config if this happens frequently." + ) + tokens = tokens[: self._max_len] + token_mask = token_mask[: self._max_len] + ar_mask = ar_mask[: self._max_len] + loss_mask = loss_mask[: self._max_len] + + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) + + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: + # Decode predicted output tokens + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) + + # Extract actions from FAST model outputs + if "Action: " not in decoded_tokens: + return np.zeros((action_horizon, action_dim), dtype=np.float32) + + # Extract actions from decoded tokens + raw_action_tokens = np.array( + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) + ) + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) + if len(action_tokens) < action_horizon * action_dim: + return np.zeros([action_horizon, action_dim], dtype=np.float32) + action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim]) + return action_tokens / self._n_bins * 2 - 1 + + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: + if isinstance(tokens, list): + tokens = np.array(tokens) + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens + + +class FSQTokenizer: + """ + FSQ tokenizer from the FAST paper baselines. 
+ """ + + def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None): + self._max_len = max_len + + assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided" + # Download tokenizer + path = download.maybe_download(fsq_tokenizer_path) + tok_path = os.path.join(path, os.listdir(path)[0]) + + # Split step from path + step = int(tok_path.split("/")[-1]) + base_path = tok_path.rsplit("/", 1)[0] + + mgr = ocp.CheckpointManager( + base_path, + item_handlers={ + "params": ocp.StandardCheckpointHandler(), + "opt_state": ocp.StandardCheckpointHandler(), + "config": ocp.JsonCheckpointHandler(), + }, + options=ocp.CheckpointManagerOptions(max_to_keep=1), + ) + + try: + restored = mgr.restore( + step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore()) + ) + config = restored["config"] + self._params = restored["params"] + self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config) + except Exception as e: + raise RuntimeError( + f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}" + ) from e + + # Compile tokenize and detokenize functions + self._tokenize_fn = jax.jit( + lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize) + ) + self._detokenize_fn = jax.jit( + lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize) + ) + + # Download base PaliGemma tokenizer + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) + with path.open("rb") as f: + self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) + + self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens + + def tokenize( + self, prompt: str, state: np.ndarray, actions: np.ndarray | None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + cleaned_text = prompt.lower().strip().replace("_", " ") + + # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + # Convention: prefix includes prompt and string-representation of state, followed by ';' + state_str = " ".join(map(str, discretized_state)) + prefix = f"Task: {cleaned_text}, State: {state_str};\n" + prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) + + if actions is not None: + raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)") + postfix_tokens = [] + + # Create output token sequence & masks + # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) + tokens = prefix_tokens + postfix_tokens + token_mask = [True] * len(tokens) + ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) + loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only + + # Pad tokens to max length + tokens_len = len(tokens) + if tokens_len < self._max_len: + padding = [False] * (self._max_len - tokens_len) + tokens = tokens + padding + token_mask = token_mask + padding + ar_mask = ar_mask + padding + loss_mask = loss_mask + padding + else: + if len(tokens) > self._max_len: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " + "Consider increasing the `max_token_len` in your model config if this happens frequently." 
+ ) + tokens = tokens[: self._max_len] + token_mask = token_mask[: self._max_len] + ar_mask = ar_mask[: self._max_len] + loss_mask = loss_mask[: self._max_len] + + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) + + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: + # Decode predicted output tokens + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) + + # Extract actions from FAST model outputs + if "Action: " not in decoded_tokens: + return np.zeros((action_horizon, action_dim), dtype=np.float32) + + # Extract actions from decoded tokens + raw_action_tokens = np.array( + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) + ) + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) + try: + # Move computation to CPU and compile on-demand + device = jax.devices("cpu")[0] + with jax.default_device(device): + detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0] + return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim]) + except Exception as e: + logging.warning(f"Error decoding FSQ: {e}") + return np.zeros((action_horizon, action_dim)) + + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: + if isinstance(tokens, list): + tokens = np.array(tokens) + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens diff --git a/capvector-pi05/src/openpi/models/tokenizer_test.py b/capvector-pi05/src/openpi/models/tokenizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3182e0a190ec94275f6af3ddb8bf896ee70be9c1 --- /dev/null +++ b/capvector-pi05/src/openpi/models/tokenizer_test.py @@ -0,0 +1,27 @@ +import numpy as np + +from openpi.models import tokenizer as _tokenizer + + +def test_tokenize(): + tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10) + tokens, masks = tokenizer.tokenize("Hello, world!") + + assert tokens.shape == (10,) + assert masks.shape == (10,) + + +def test_fast_tokenizer(): + prompt = "Hello, world!" 
+ state = np.random.rand(5).astype(np.float32)
+ action = np.random.rand(3, 2).astype(np.float32)
+ tokenizer = _tokenizer.FASTTokenizer(max_len=256)
+ tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action)
+
+ assert tokens.shape == (256,)
+ assert token_masks.shape == (256,)
+ assert ar_masks.shape == (256,)
+ assert loss_masks.shape == (256,)
+
+ act = tokenizer.extract_actions(tokens, 3, 2)
+ assert act.shape == (3, 2)
diff --git a/capvector-pi05/src/openpi/models/utils/fsq_tokenizer.py b/capvector-pi05/src/openpi/models/utils/fsq_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8f4033d1a438d8652414871b2c46fd5a39af0a
--- /dev/null
+++ b/capvector-pi05/src/openpi/models/utils/fsq_tokenizer.py
@@ -0,0 +1,472 @@
+import math
+from typing import Any, Literal
+
+import chex
+import einops
+from flax import linen as nn
+from flax.linen.module import Module
+from flax.linen.module import compact
+from flax.struct import dataclass
+from flax.typing import Array
+import jax
+import jax.numpy as jnp
+
+
+class FsqCodebook(nn.Module):
+ input_dim: int
+ target_codebook_size: int
+ codebook_type: Literal["fsq", "lfq", "custom"]
+
+ _bins_per_dim: tuple[int] | None = None
+
+ @property
+ def bins_per_dim(self) -> tuple[int]:
+ if self._bins_per_dim is not None:
+ return self._bins_per_dim
+
+ if self.codebook_type == "fsq":
+ return self._get_bins_fsq(self.target_codebook_size)
+ elif self.codebook_type == "lfq": # noqa: RET505
+ return self._get_bins_lfq(self.target_codebook_size)
+ elif self.codebook_type == "custom":
+ return self._get_bins_custom(self.target_codebook_size)
+ else:
+ raise ValueError(f"Codebook type {self.codebook_type} not supported.")
+
+ @property
+ def place_values(self) -> jnp.ndarray:
+ place_values = [1]
+ for b in self.bins_per_dim[:-1]:
+ place_values.append(place_values[-1] * b)
+ return jnp.array(place_values)
+
+ @staticmethod
+ def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:
+ """
+ Get bins per dimension based on codebook size, from the original FSQ paper.
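+
+ Example: target_codebook_size = 2**10 maps to bins (8, 5, 5, 5), i.e. a
+ codebook of 8 * 5 * 5 * 5 = 1000 (roughly 2**10) entries.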
+ """ + if target_codebook_size == 2**8: + return (8, 6, 5) + elif target_codebook_size == 2**10: # noqa: RET505 + return (8, 5, 5, 5) + elif target_codebook_size == 2**12: + return (7, 5, 5, 5, 5) + elif target_codebook_size == 2**14: + return (8, 8, 8, 6, 5) + elif target_codebook_size == 2**16: + return (8, 8, 8, 5, 5, 5) + else: + raise ValueError(f"Codebook size {target_codebook_size} not supported.") + + @staticmethod + def _get_bins_custom(target_codebook_size: int) -> tuple[int]: + if target_codebook_size == 2**8: + return (16, 16) + elif target_codebook_size == 2**10: # noqa: RET505 + return (32, 32) + elif target_codebook_size == 2**12: + return (64, 64) + elif target_codebook_size == 2**14: + return (128, 128) + elif target_codebook_size == 2**16: + return (256, 256) + return None + + @staticmethod + def _get_bins_lfq(target_codebook_size: int) -> tuple[int]: + """ + Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension) + """ + assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ" + + return (2,) * int(math.log2(target_codebook_size)) + + def setup(self): + self.proj_down = nn.Dense(len(self.bins_per_dim)) + self.proj_up = nn.Dense(self.input_dim) + + def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: + tokens, z = self.encode(inputs) + output = self.decode(tokens, z_grad=z) + return tokens, output + + def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: + bases = jnp.array(self.bins_per_dim) + + x = self.proj_down(inputs) + z = jnp.tanh(x) + + # Quantize + digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32) + tokens = self.undigitize(digits) + + return tokens, z + + def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray: + bases = jnp.array(self.bins_per_dim) + digits = self.digitize(tokens) + + z_q = digits / (bases - 1) * 2 - 1 + + if z_grad is not None: + chex.assert_equal_shape([z_q, z_grad]) + z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad + + return self.proj_up(z_q) + + def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray: + return jnp.sum(digits * jnp.array(self.place_values), axis=-1) + + def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray: + return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim) + + @property + def vocab_size(self) -> int: + return math.prod(self.bins_per_dim) + + +class ResNetDownBlock(nn.Module): + stride: int = 1 + n_filters: int = 64 + dropout_rate: float = 0.0 + group_size: int = 32 + + @nn.compact + def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: + skip = x + + if self.stride > 1 or x.shape[-1] != self.n_filters: + skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip) + + x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x) + x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x) + x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) + x = nn.relu(x) + x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x) + + return skip + x + + +class ResNetUpBlock(nn.Module): + stride: int = 1 + n_filters: int = 64 + dropout_rate: float = 0.0 + group_size: int = 32 + + @nn.compact + def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: + skip = x + + if self.stride > 1: + skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip) + + x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x) + x = 
nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x) + x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) + x = nn.relu(x) + x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x) + + return skip + x + + +@dataclass +class LfqCodebookOutput: + tokens: jnp.ndarray + z: jnp.ndarray + z_q: jnp.ndarray + token_log_probs: jnp.ndarray + commit_loss: jnp.ndarray + + +class LookupFreeQuantization(nn.Module): + num_dims: int + latent_dim: int + + def setup(self): + self.codebook = jnp.array([-1, 1]) + self.activation = nn.tanh + + self.project_down = nn.Dense(self.num_dims) + self.project_up = nn.Dense(self.latent_dim) + + def encode(self, z: jnp.ndarray) -> jnp.ndarray: + z = self.project_down(z) + token_squared_distances = jnp.square(z[..., None] - self.codebook) + token_bits = jnp.argmin(token_squared_distances, axis=-1) + return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1) + + def decode(self, tokens: jnp.ndarray) -> jnp.ndarray: + token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32) + return self.project_up(self.codebook[token_bits]) + + def loss(self, x: jnp.ndarray) -> LfqCodebookOutput: + z = self.project_down(x) + z = self.activation(z) + + token_squared_distances = jnp.square(z[..., None] - self.codebook) + tokens = jnp.argmin(token_squared_distances, axis=-1) + + token_bit_log_probs = -token_squared_distances + # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs + token_bit_expansions = jnp.bitwise_and( + jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None] + ).astype(jnp.int32) + token_log_probs = ( + token_bit_log_probs[..., 0] @ (1 - token_bit_expansions) + + token_bit_log_probs[..., 1] @ token_bit_expansions + ) # (batch_size, num_tokens, 2 ** num_dims) + token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1)) + chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims)) + + z_q = self.codebook[tokens] + commit_loss = jnp.square(z - z_q).mean() + z_q = jax.lax.stop_gradient(z_q - z) + z + + z_q = self.project_up(z_q) + z = self.project_up(z) + + tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1) + return LfqCodebookOutput( + tokens=tokens, + z=z, + z_q=z_q, + token_log_probs=jnp.zeros(()), + commit_loss=commit_loss, + ) + + +def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray: + return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q)) + + +class GeGLU(Module): + """Gated Linear Unit with GELU (GeGLU) activation function. + GeGLU is a Flax layer that combines a linear transformation with a GELU + activation function in a gating mechanism. It is often used in Transformer models + to provide non-linear capabilities while preserving a strong linear component. + + Attributes: + features: the number of output features (default: None). + """ + + output_dim: int = -1 + + @compact + def __call__(self, inputs: Array) -> Array: + """Applies the GeGLU activation to the inputs. + Args: + inputs: the nd-array to apply the GeGLU activation function to. + Returns: + The transformed input. 
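+
+ Sketch of the computation below: h = Dense(2 * output_dim)(inputs);
+ value, gate = split(h, 2); return value * gelu(gate).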
+ """ + output_dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim + + x = nn.Dense(output_dim * 2)(inputs) + x, gate = x[..., :output_dim], x[..., output_dim:] + return x * nn.gelu(gate) + + +class CrossAttentionLayer(nn.Module): + dropout_rate: float = 0.0 + num_heads: int = None + causal: bool = False + mlp_ratio: float = 4.0 + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + y: jnp.ndarray, + *, + mask_self: jnp.ndarray | None = None, + mask_cross: jnp.ndarray | None = None, + train: bool = True, + ) -> jnp.ndarray: + d_embed = x.shape[-1] + seq_len_q = x.shape[-2] + seq_len_k = y.shape[-2] + + if self.causal: + # One block size will be 1 + bs_q = max(seq_len_q // seq_len_k, 1) + bs_k = max(seq_len_k // seq_len_q, 1) + + mask_self = nn.make_causal_mask(x[..., 0]) + mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k) + + # Self-attention block + skip = x + x = nn.LayerNorm()(x) + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads or d_embed // 64, + dropout_rate=self.dropout_rate, + deterministic=not train, + )(x, x, x, mask=mask_self) + x = skip + x + + # Cross-attention block + skip = x + x = nn.LayerNorm()(x) + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads or d_embed // 64, + dropout_rate=self.dropout_rate, + deterministic=not train, + )(x, y, y, mask=mask_cross) + x = skip + x + + # MLP block + skip = x + x = nn.LayerNorm()(x) + x = nn.Dense(int(d_embed * self.mlp_ratio))(x) + x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) + x = GeGLU()(x) + x = nn.Dense(d_embed)(x) + return skip + x + + +def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray: + seq_len, d_embed = shape + + position = jnp.arange(0, seq_len, 1) + div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed)) + return jnp.concatenate( + [ + jnp.sin(position[:, jnp.newaxis] * div_term), + jnp.cos(position[:, jnp.newaxis] * div_term), + ], + axis=-1, + ) + + +class TokenizerEncoderDecoder(nn.Module): + num_tokens: int + num_cross_tokens: int + num_layers: int + causal: bool + + mlp_ratio: float = 4.0 + use_state_conditioning: bool = False + + @nn.compact + def __call__( + self, + y: jnp.ndarray, + *, + train: bool = True, + state_conditioning: jnp.ndarray | None = None, + mask: jnp.ndarray | None = None, + ) -> jnp.ndarray: + x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1])) + x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:]) + + if mask is not None: + # mask is (batch_dims..., num_cross_tokens) + chex.assert_equal_shape([y[..., 0], mask]) + attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens) + else: + attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens)) + + if self.use_state_conditioning: + assert state_conditioning is not None, "State conditioning is required for this model." 
+ state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :] + y = jnp.concatenate([y, state_embed], axis=-2) + attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1) + + y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:]) + + for _ in range(self.num_layers): + x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)( + x, y, train=train, mask_self=None, mask_cross=attn_mask + ) + + return x + + +class FsqAttentionTokenizer(nn.Module): + embed_dim: int + data_dim: int + data_horizon: int + num_tokens: int + num_layers: int + target_codebook_size: int + causal: bool = False + mlp_ratio: float = 2.0 + + bound: float | None = None + + use_state_conditioning: bool = False + + @property + def vocab_size(self) -> int: + return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size)) # noqa: SLF001 + + def setup(self): + self.proj = nn.Dense(self.embed_dim) + self.encoder = TokenizerEncoderDecoder( + num_tokens=self.num_tokens, + num_cross_tokens=self.data_horizon, + num_layers=self.num_layers, + causal=self.causal, + use_state_conditioning=self.use_state_conditioning, + mlp_ratio=self.mlp_ratio, + ) + self.codebook = FsqCodebook( + input_dim=self.embed_dim, + target_codebook_size=self.target_codebook_size, + codebook_type="custom", + ) + self.decoder = TokenizerEncoderDecoder( + num_tokens=self.data_horizon, + num_cross_tokens=self.num_tokens, + num_layers=self.num_layers, + causal=self.causal, + use_state_conditioning=self.use_state_conditioning, + mlp_ratio=self.mlp_ratio, + ) + + self.proj_mean = nn.Dense(self.data_dim) + self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0)) + + def tokenize( + self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False + ) -> tuple[jnp.ndarray, jnp.ndarray]: + if self.bound is not None: + action = jnp.clip(action, -self.bound, self.bound) + + x = self.proj(action) + x = self.encoder(x, train=train, state_conditioning=obs) + + return self.codebook.encode(x) + + def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray: + x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs) + mean = self.proj_mean(x) + return mean * self.out_scale + + def loss( + self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True + ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: + # Encode + x = self.proj(action) + z = self.encoder(x, train=train, state_conditioning=obs) + + # Quantize + tokens, z = self.codebook(z) + + # Decode + x = self.decoder(z, train=train, state_conditioning=obs) + mean = self.proj_mean(x) * self.out_scale + + mse = jnp.mean(jnp.square(action - mean)) + mae = jnp.mean(jnp.abs(action - mean)) + + return mse, { + "mse": mse, + "mae": mae, + } + + def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: + """ + Dummy for .init + """ + return self.loss(*args, **kwargs) diff --git a/capvector-pi05/src/openpi/models/vit.py b/capvector-pi05/src/openpi/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..1408e28d5ae9dd75aaf19aa44dd6c82fd9dda7bb --- /dev/null +++ b/capvector-pi05/src/openpi/models/vit.py @@ -0,0 +1,307 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py.""" + +from collections.abc import Callable +from typing import Any + +import flax.linen as nn +import jax +import jax.numpy as jnp + +from openpi.models import resnet as models_resnet + +Array = Any +PRNGKey = Any +Shape = tuple[int] +Dtype = Any + + +class IdentityLayer(nn.Module): + """Identity layer, convenient for giving a name to an array.""" + + @nn.compact + def __call__(self, x): + return x + + +class AddPositionEmbs(nn.Module): + """Adds learned positional embeddings to the inputs. + + Attributes: + posemb_init: positional embedding initializer. + """ + + posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] + param_dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, inputs): + """Applies the AddPositionEmbs module. + + Args: + inputs: Inputs to the layer. + + Returns: + Output tensor with shape `(bs, timesteps, in_dim)`. + """ + # inputs.shape is (batch_size, seq_len, emb_dim). + assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}" + pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) + pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype) + return inputs + pe + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + + mlp_dim: int + dtype: Dtype = jnp.float32 + param_dtype: Dtype = jnp.float32 + out_dim: int | None = None + dropout_rate: float = 0.1 + kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform() + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6) + + @nn.compact + def __call__(self, inputs, *, deterministic): + """Applies Transformer MlpBlock module.""" + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + features=self.mlp_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )( # pytype: disable=wrong-arg-types + inputs + ) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + output = nn.Dense( + features=actual_out_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )( # pytype: disable=wrong-arg-types + x + ) + return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic) + + +class Encoder1DBlock(nn.Module): + """Transformer encoder layer. + + Attributes: + inputs: input data. + mlp_dim: dimension of the mlp on top of attention block. + dtype: the dtype of the computation (default: float32). + dropout_rate: dropout rate. + attention_dropout_rate: dropout for attention heads. + deterministic: bool, deterministic or not (to apply dropout). + num_heads: Number of heads in nn.MultiHeadDotProductAttention + """ + + mlp_dim: int + num_heads: int + dtype: Dtype = jnp.float32 + dropout_rate: float = 0.1 + attention_dropout_rate: float = 0.1 + + @nn.compact + def __call__(self, inputs, deterministic): + """Applies Encoder1DBlock module. 
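+
+ The layout is pre-norm residual: LayerNorm -> self-attention -> residual add,
+ then LayerNorm -> MLP -> residual add.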
+ + Args: + inputs: Inputs to the layer. + deterministic: Dropout will not be applied when set to true. + + Returns: + output after transformer encoder block. + """ + + # Attention block. + assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}" + x = nn.LayerNorm(dtype=self.dtype)(inputs) + x = nn.MultiHeadDotProductAttention( + dtype=self.dtype, + kernel_init=nn.initializers.xavier_uniform(), + broadcast_dropout=False, + deterministic=deterministic, + dropout_rate=self.attention_dropout_rate, + num_heads=self.num_heads, + # why isn't this true by default??? + force_fp32_for_softmax=True, + )(x, x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = x + inputs + + # MLP block. + y = nn.LayerNorm(dtype=self.dtype)(x) + y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)( + y, deterministic=deterministic + ) + + return x + y, None + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation. + + Attributes: + num_layers: number of layers + mlp_dim: dimension of the mlp on top of attention block + num_heads: Number of heads in nn.MultiHeadDotProductAttention + dropout_rate: dropout rate. + attention_dropout_rate: dropout rate in self attention. + """ + + dtype: jax.typing.DTypeLike + num_layers: int + mlp_dim: int + num_heads: int + dropout_rate: float = 0.1 + attention_dropout_rate: float = 0.1 + add_position_embedding: bool = True + + @nn.compact + def __call__(self, x, *, train): + """Applies Transformer model on the inputs. + + Args: + x: Inputs to the layer. + train: Set to `True` when training. + + Returns: + output of a transformer encoder. + """ + assert x.ndim == 3 # (batch, len, emb) + + if self.add_position_embedding: + x = AddPositionEmbs( + posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. + name="posembed_input", + )(x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) + + x = x.astype(self.dtype) + # Input Encoder + block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,)) + x, _ = nn.scan( + block, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=nn.broadcast, + length=self.num_layers, + )( + name="encoderblock", + mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, + attention_dropout_rate=self.attention_dropout_rate, + dtype=self.dtype, + num_heads=self.num_heads, + )(x, not train) + return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x) + + +class VisionTransformer(nn.Module): + """VisionTransformer.""" + + dtype: jax.typing.DTypeLike + num_classes: int + patches: Any + transformer: Any + hidden_size: int + resnet: Any | None = None + representation_size: int | None = None + classifier: str = "token" + head_bias_init: float = 0.0 + encoder: type[nn.Module] = Encoder + model_name: str | None = None + + @nn.compact + def __call__(self, inputs, *, train): + x = inputs + # (Possibly partial) ResNet root. + if self.resnet is not None: + width = int(64 * self.resnet.width_factor) + + # Root block. + x = models_resnet.StdConv( + features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root" + )(x) + x = nn.GroupNorm(name="gn_root")(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME") + + # ResNet stages. 
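+ # (Hybrid ViT: with a ResNet root, the "embedding" conv below runs on ResNet
+ # feature maps rather than raw pixels, so `patches.size` applies to the
+ # reduced feature-map resolution.)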
+ if self.resnet.num_layers: + x = models_resnet.ResNetStage( + block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1" + )(x) + for i, block_size in enumerate(self.resnet.num_layers[1:], 1): + x = models_resnet.ResNetStage( + block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}" + )(x) + + n, h, w, c = x.shape + + # We can merge s2d+emb into a single conv; it's the same. + x = nn.Conv( + features=self.hidden_size, + kernel_size=self.patches.size, + strides=self.patches.size, + padding="VALID", + name="embedding", + )(x) + + # Here, x is a grid of embeddings. + + # (Possibly partial) Transformer. + if self.transformer is not None: + n, h, w, c = x.shape + x = jnp.reshape(x, [n, h * w, c]) + + # If we want to add a class token, add it here. + if self.classifier in ["token", "token_unpooled"]: + cls = self.param("cls", nn.initializers.zeros, (1, 1, c)) + cls = jnp.tile(cls, [n, 1, 1]) + x = jnp.concatenate([cls, x], axis=1) + + x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train) + + if self.classifier == "token": + x = x[:, 0] + elif self.classifier == "gap": + x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) + elif self.classifier in ["unpooled", "token_unpooled"]: + pass + else: + raise ValueError(f"Invalid classifier={self.classifier}") + + if self.representation_size is not None: + x = nn.Dense(features=self.representation_size, name="pre_logits")(x) + x = nn.tanh(x) + else: + x = IdentityLayer(name="pre_logits")(x) + + if self.num_classes: + x = nn.Dense( + features=self.num_classes, + name="head", + kernel_init=nn.initializers.zeros, + bias_init=nn.initializers.constant(self.head_bias_init), + )(x) + return x diff --git a/capvector-pi05/src/openpi/models_pytorch/gemma_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/gemma_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..54ee1693970659b0780d0ce8f0413df8854511b2 --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/gemma_pytorch.py @@ -0,0 +1,291 @@ +from typing import Literal + +import pytest +import torch +from torch import nn +from transformers import GemmaForCausalLM +from transformers import PaliGemmaForConditionalGeneration +from transformers.models.auto import CONFIG_MAPPING +from transformers.models.gemma import modeling_gemma + + +class PaliGemmaWithExpertModel(nn.Module): + def __init__( + self, + vlm_config, + action_expert_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + ): + if use_adarms is None: + use_adarms = [False, False] + super().__init__() + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 
4304
+ vlm_config_hf.vision_config.projection_dim = 2048
+ vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
+ vlm_config_hf.vision_config.torch_dtype = "float32"
+
+ action_expert_config_hf = CONFIG_MAPPING["gemma"](
+ head_dim=action_expert_config.head_dim,
+ hidden_size=action_expert_config.width,
+ intermediate_size=action_expert_config.mlp_dim,
+ num_attention_heads=action_expert_config.num_heads,
+ num_hidden_layers=action_expert_config.depth,
+ num_key_value_heads=action_expert_config.num_kv_heads,
+ vocab_size=257152,
+ hidden_activation="gelu_pytorch_tanh",
+ torch_dtype="float32",
+ use_adarms=use_adarms[1],
+ adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
+ )
+
+ self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
+ self.gemma_expert.model.embed_tokens = None
+
+ self.to_bfloat16_for_selected_params(precision)
+
+ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
+ if precision == "bfloat16":
+ self.to(dtype=torch.bfloat16)
+ elif precision == "float32":
+ self.to(dtype=torch.float32)
+ return
+ else:
+ raise ValueError(f"Invalid precision: {precision}")
+
+ params_to_keep_float32 = [
+ "vision_tower.vision_model.embeddings.patch_embedding.weight",
+ "vision_tower.vision_model.embeddings.patch_embedding.bias",
+ "vision_tower.vision_model.embeddings.position_embedding.weight",
+ "input_layernorm",
+ "post_attention_layernorm",
+ "model.norm",
+ ]
+
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_keep_float32):
+ param.data = param.data.to(dtype=torch.float32)
+
+ def embed_image(self, image: torch.Tensor):
+ return self.paligemma.model.get_image_features(image)
+
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.paligemma.language_model.embed_tokens(tokens)
+
+ def forward(
+ self,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | modeling_gemma.Cache | None = None,
+ inputs_embeds: list[torch.FloatTensor] | None = None,
+ use_cache: bool | None = None,
+ adarms_cond: list[torch.Tensor] | None = None,
+ output_hidden_states: bool | None = None,
+ ):
+ if adarms_cond is None:
+ adarms_cond = [None, None]
+ if inputs_embeds[1] is None:
+ prefix_output = self.paligemma.language_model.forward(
+ inputs_embeds=inputs_embeds[0],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
+ )
+ prefix_past_key_values = prefix_output.past_key_values
+ prefix_output = prefix_output.last_hidden_state
+ suffix_output = None
+ elif inputs_embeds[0] is None:
+ suffix_output = self.gemma_expert.model.forward(
+ inputs_embeds=inputs_embeds[1],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
+ )
+ suffix_output = suffix_output.last_hidden_state
+ prefix_output = None
+ prefix_past_key_values = None
+ else:
+ models = [self.paligemma.language_model, self.gemma_expert.model]
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
+
+ # Check if gradient checkpointing is enabled for any of the models
+ use_gradient_checkpointing = (
+ hasattr(self.gemma_expert.model, "gradient_checkpointing")
+ and
self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Force enable gradient checkpointing if we're in training mode and the model supports it + if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"): + if not self.gemma_expert.model.gradient_checkpointing: + print("Forcing gradient checkpointing to be enabled for Gemma expert model") + self.gemma_expert.model.gradient_checkpointing = True + use_gradient_checkpointing = True + + # Debug gradient checkpointing status + if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed: + print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") + print(f"Model training mode: {self.training}") + print( + f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}" + ) + if hasattr(self.gemma_expert.model, "gradient_checkpointing"): + print( + f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}" + ) + self._debug_gc_printed = True + + # Define the complete layer computation function for gradient checkpointing + def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): + models = [self.paligemma.language_model, self.gemma_expert.model] + + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + + batch_size = query_states.shape[0] + scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling + + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + self.paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + 
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + + return outputs_embeds + + # Process all layers with gradient checkpointing if enabled + all_hidden_states = () if output_hidden_states else None + for layer_idx in range(num_layers): + if output_hidden_states: + all_hidden_states += (inputs_embeds,) + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond + ) + + # Old code removed - now using compute_layer_complete function above + + # final norm + # Define final norm computation function for gradient checkpointing + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + if output_hidden_states: + all_hidden_states += (outputs_embeds,) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + if output_hidden_states: + return [prefix_output, suffix_output], prefix_past_key_values, all_hidden_states + + return [prefix_output, suffix_output], prefix_past_key_values diff --git a/capvector-pi05/src/openpi/models_pytorch/pi0_align_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/pi0_align_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..abd9a73f968808fe49d7ca5d168c095ab699e466 --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/pi0_align_pytorch.py @@ -0,0 +1,528 @@ +import logging +import math + +import torch +from torch import Tensor +from torch import nn +import torch.nn.functional as F # noqa: N812 + +import openpi.models.gemma as _gemma +from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel +import openpi.models_pytorch.preprocessing_pytorch as _preprocessing + +from vggt.utils.load_fn import preprocess_images_from_openpi +from vggt.heads.utils import custom_pooling + + +def get_safe_dtype(target_dtype, device_type): + """Get a safe dtype for the given device type.""" + if device_type == "cpu": + # CPU doesn't support bfloat16, use float32 instead + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + 
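+# Behaviour of get_safe_dtype above, for reference: bfloat16 is downgraded to
+# float32 on CPU, everything else passes through unchanged, e.g.
+#   get_safe_dtype(torch.bfloat16, "cpu")   # -> torch.float32
+#   get_safe_dtype(torch.bfloat16, "cuda")  # -> torch.bfloat16
+
+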
+def create_sinusoidal_pos_embedding(
+    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
+) -> Tensor:
+    """Computes sine-cosine positional embedding vectors for scalar positions."""
+    if dimension % 2 != 0:
+        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
+
+    if time.ndim != 1:
+        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")
+
+    # torch.device(...) accepts both strings and device objects, so the "cpu"
+    # default and callers passing tensor.device both work here.
+    dtype = get_safe_dtype(torch.float64, torch.device(device).type)
+    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
+    period = min_period * (max_period / min_period) ** fraction
+
+    # Compute the outer product
+    scaling_factor = 1.0 / period * 2 * math.pi
+    sin_input = scaling_factor[None, :] * time[:, None]
+    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+
+
+def sample_beta(alpha, beta, bsize, device):
+    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
+    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
+    dist = torch.distributions.Beta(alpha_t, beta_t)
+    return dist.sample((bsize,))
+
+
+def make_att_2d_masks(pad_masks, att_masks):
+    """Copied from big_vision.
+
+    Tokens can attend to valid input tokens whose cumulative att_masks value is
+    smaller than or equal to their own. This way `att_masks` int[B, N] can be
+    used to set up several types of attention, for example:
+
+      [[1 1 1 1 1 1]]: pure causal attention.
+
+      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+          themselves and the last 3 tokens have a causal attention. The first
+          entry could also be a 1 without changing behaviour.
+
+      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+          block can attend all previous blocks and all tokens on the same block.
+
+    Args:
+      pad_masks: bool[B, N] true if it's part of the input, false if padding.
+      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
+          it and 0 where it shares the same attention mask as the previous token.
+ """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + return att_2d_masks & pad_2d_masks + + +class PI0Pytorch(nn.Module): + def __init__(self, config, extra_config): + super().__init__() + self.config = config + self.pi05 = config.pi05 + + paligemma_config = _gemma.get_config(config.paligemma_variant) + action_expert_config = _gemma.get_config(config.action_expert_variant) + + self.LLM_width = paligemma_config.width + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True] if self.pi05 else [False, False], + precision=config.dtype, + ) + + self.action_in_proj = nn.Linear(32, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, 32) + + if self.pi05: + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + else: + self.state_proj = nn.Linear(32, action_expert_config.width) + self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) + self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Specific config for SpatialForcing alignment + self.vla_layers_align = extra_config.vla_layers_align + self.vggt_layers_align = extra_config.vggt_layers_align + self.pooling_func = extra_config.pooling_func + self.use_vggt_pe = extra_config.use_vggt_pe + + msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`." 
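+        # The import below probes a marker module that ships with the patched
+        # transformers sources, so an ImportError and a failed check both mean
+        # the copy step was skipped.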
+ try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def is_gradient_checkpointing_enabled(self): + """Check if gradient checkpointing is enabled.""" + return self.gradient_checkpointing_enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) + + def _preprocess_observation(self, observation, *, train=True, get_wo_aug=False): + """Helper method to preprocess observation.""" + observation = _preprocessing.preprocess_observation_pytorch(observation, train=train, get_wo_aug=get_wo_aug) + return ( + list(observation.images.values()), + list(observation.img_wo_aug.values()) if get_wo_aug else None, + list(observation.image_padding_mask.values()), + list(observation.image_masks.values()), + observation.tokenized_prompt, + observation.tokenized_prompt_mask, + observation.state, + ) + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for PaliGemma transformer processing. 
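+
+        Returns a tuple ``(embs, pad_masks, att_masks, img_len)``; ``img_len``
+        counts the image tokens and is used later to slice the vision features
+        for the alignment loss.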
+ """ + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + img_len = len(att_masks) + + # Process language tokens + def lang_embed_func(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + # full attention between image and language inputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + # Get batch size from the first dimension of the concatenated tensors + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, img_len + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + if not self.pi05: + if self.state_proj.weight.dtype == torch.float32: + state = state.to(torch.float32) + + # Embed state + def state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) + + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] + + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + if not self.pi05: + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + # Apply MLP layers + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + return self.action_time_mlp_out(x) + + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) + adarms_cond = None + else: + # time MLP (for adaRMS) + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, 
device=timestep.device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * (self.config.action_horizon - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward(self, observation, actions, vggt, align_proj, noise=None, time=None) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + images, img_wo_aug, img_padding_mask, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation( + observation, train=True, get_wo_aug=True + ) + img_resize_wo_aug = preprocess_images_from_openpi(img_wo_aug) # specific for VGGT with 518px input + + # =================================== VLA action loss =================================== + + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks, img_len = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + # Prepare attention masks + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + # Apply gradient checkpointing if enabled + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _, all_hidden_states = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + output_hidden_states=True, + ) + return suffix_out, all_hidden_states + + suffix_out, all_hidden_states = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + ) + + suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out.to(dtype=torch.float32) + + # Apply gradient checkpointing to final action projection if enabled + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + action_loss = F.mse_loss(u_t, v_t) + + # =================================== Alignment loss =================================== + + # VLA hidden states + (prefix_hidden, _) = all_hidden_states[self.vla_layers_align] # 18 total layers of paligemma + vision_hidden = prefix_hidden[:, :img_len, :] + + # VGGT hidden states + with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad(): + vggt_output = vggt(img_resize_wo_aug) + agg_vggt_hidden = 
vggt_output["features"][self.vggt_layers_align] # 24 for total layers of VGGT + patch_start_idx = vggt_output["patch_start_idx"] + original_img = vggt_output["images"] + vggt_hidden = agg_vggt_hidden[:, :, patch_start_idx:, :] + + # Resample VGGT hidden states to match the resolution of VLA hidden states + H, W = original_img.shape[-2:] + patch_h, patch_w = H // vggt.patch_size, W // vggt.patch_size + vggt_hidden = custom_pooling( + vggt_hidden, (patch_h, patch_w), (H, W), vision_hidden, self.pooling_func, self.use_vggt_pe + ) + + # empty image feature masks for alignment loss + tokens_per_img = img_len // len(images) + img_masks_stack = torch.stack(img_masks, dim=1) + align_mask = torch.repeat_interleave(img_masks_stack, repeats=tokens_per_img, dim=1) + + # useless non-rectangular image padding feature masks for alignment loss + img_padding_mask = torch.stack(img_padding_mask, dim=1) + target_size = img_padding_mask.shape[-1] // 14 # 224/14, where 14 is the patch size of Gemma encoder + mask_downsampled = F.interpolate( + img_padding_mask.float(), + size=(target_size, target_size), + mode='nearest' + ).bool().flatten(start_dim=1) + assert align_mask.shape == mask_downsampled.shape, \ + "align_mask shape don't match img_padding_mask shape, please manually modify the patch size of Gemma encoder (now is 14)" + align_mask = mask_downsampled & align_mask + + # calculate align loss + with torch.autocast("cuda", dtype=torch.bfloat16): + align_loss = align_proj(vision_hidden, vggt_hidden, align_mask) + + return action_loss, align_loss + + @torch.no_grad() + def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = observation.state.shape[0] + if noise is None: + actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) + noise = self.sample_noise(actions_shape, device) + + images, _, _, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False) + + prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Compute image and language key value cache + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + + # Euler step - use new tensor assignment instead of in-place operation + x_t = x_t + dt * v_t + time += dt + return x_t + + def denoise_step( + self, + state, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) + + suffix_len = 
suffix_pad_masks.shape[1]
+        batch_size = prefix_pad_masks.shape[0]
+        prefix_len = prefix_pad_masks.shape[1]
+
+        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
+
+        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
+
+        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
+
+        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
+        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+
+        # Prepare attention masks
+        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
+        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001
+
+        outputs_embeds, _ = self.paligemma_with_expert.forward(
+            attention_mask=full_att_2d_masks_4d,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=[None, suffix_embs],
+            use_cache=False,
+            adarms_cond=[None, adarms_cond],
+        )
+
+        suffix_out = outputs_embeds[1]
+        suffix_out = suffix_out[:, -self.config.action_horizon :]
+        suffix_out = suffix_out.to(dtype=torch.float32)
+        return self.action_out_proj(suffix_out)
diff --git a/capvector-pi05/src/openpi/models_pytorch/pi0_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/pi0_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32308574dd3d76149a1b1236e5d956d6de100cf
--- /dev/null
+++ b/capvector-pi05/src/openpi/models_pytorch/pi0_pytorch.py
@@ -0,0 +1,461 @@
+import logging
+import math
+
+import torch
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F  # noqa: N812
+
+import openpi.models.gemma as _gemma
+from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
+import openpi.models_pytorch.preprocessing_pytorch as _preprocessing
+
+
+def get_safe_dtype(target_dtype, device_type):
+    """Get a safe dtype for the given device type."""
+    if device_type == "cpu":
+        # CPU doesn't support bfloat16, use float32 instead
+        if target_dtype == torch.bfloat16:
+            return torch.float32
+        if target_dtype == torch.float64:
+            return torch.float64
+    return target_dtype
+
+
+def create_sinusoidal_pos_embedding(
+    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
+) -> Tensor:
+    """Computes sine-cosine positional embedding vectors for scalar positions."""
+    if dimension % 2 != 0:
+        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
+
+    if time.ndim != 1:
+        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")
+
+    # torch.device(...) accepts both strings and device objects, so the "cpu"
+    # default and callers passing tensor.device both work here.
+    dtype = get_safe_dtype(torch.float64, torch.device(device).type)
+    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
+    period = min_period * (max_period / min_period) ** fraction
+
+    # Compute the outer product
+    scaling_factor = 1.0 / period * 2 * math.pi
+    sin_input = scaling_factor[None, :] * time[:, None]
+    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+
+
+def sample_beta(alpha, beta, bsize, device):
+    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
+    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
+    dist = torch.distributions.Beta(alpha_t, beta_t)
+    return dist.sample((bsize,))
+
+
+def make_att_2d_masks(pad_masks, att_masks):
+    """Copied from big_vision.
+
+    Tokens can attend to valid input tokens whose cumulative att_masks value is
+    smaller than or equal to their own. This way `att_masks` int[B, N] can be
+    used to set up several types of attention, for example:
+
+      [[1 1 1 1 1 1]]: pure causal attention.
+
+      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+          themselves and the last 3 tokens have a causal attention. The first
+          entry could also be a 1 without changing behaviour.
+
+      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+          block can attend all previous blocks and all tokens on the same block.
+
+    Args:
+      pad_masks: bool[B, N] true if it's part of the input, false if padding.
+      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
+          it and 0 where it shares the same attention mask as the previous token.
+    """
+    if att_masks.ndim != 2:
+        raise ValueError(att_masks.ndim)
+    if pad_masks.ndim != 2:
+        raise ValueError(pad_masks.ndim)
+
+    cumsum = torch.cumsum(att_masks, dim=1)
+    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
+    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
+    return att_2d_masks & pad_2d_masks
+
+
+class PI0Pytorch(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.pi05 = config.pi05
+
+        paligemma_config = _gemma.get_config(config.paligemma_variant)
+        action_expert_config = _gemma.get_config(config.action_expert_variant)
+
+        self.paligemma_with_expert = PaliGemmaWithExpertModel(
+            paligemma_config,
+            action_expert_config,
+            use_adarms=[False, True] if self.pi05 else [False, False],
+            precision=config.dtype,
+        )
+
+        self.action_in_proj = nn.Linear(32, action_expert_config.width)
+        self.action_out_proj = nn.Linear(action_expert_config.width, 32)
+
+        if self.pi05:
+            self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
+            self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
+        else:
+            self.state_proj = nn.Linear(32, action_expert_config.width)
+            self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
+            self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
+
+        torch.set_float32_matmul_precision("high")
+        self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
+
+        # Initialize gradient checkpointing flag
+        self.gradient_checkpointing_enabled = False
+
+        msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
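+        # Fail fast if the patched transformers sources were not copied in: the
+        # marker module imported below exists only in the replaced install.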
+ try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def is_gradient_checkpointing_enabled(self): + """Check if gradient checkpointing is enabled.""" + return self.gradient_checkpointing_enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) + + def _preprocess_observation(self, observation, *, train=True): + """Helper method to preprocess observation.""" + observation = _preprocessing.preprocess_observation_pytorch(observation, train=train) + return ( + list(observation.images.values()), + list(observation.image_masks.values()), + observation.tokenized_prompt, + observation.tokenized_prompt_mask, + observation.state, + ) + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for PaliGemma transformer processing. 
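+
+        Returns a tuple ``(embs, pad_masks, att_masks)`` covering the
+        concatenated image and language prefix tokens.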
+ """ + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + # Process language tokens + def lang_embed_func(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + # full attention between image and language inputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + # Get batch size from the first dimension of the concatenated tensors + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + if not self.pi05: + if self.state_proj.weight.dtype == torch.float32: + state = state.to(torch.float32) + + # Embed state + def state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) + + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] + + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + if not self.pi05: + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + # Apply MLP layers + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + return self.action_time_mlp_out(x) + + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) + adarms_cond = None + else: + # time MLP (for adaRMS) + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) + 
pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * (self.config.action_horizon - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward(self, observation, actions, noise=None, time=None) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True) + + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + # Prepare attention masks + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + # Apply gradient checkpointing if enabled + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + ) + + suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out.to(dtype=torch.float32) + + # Apply gradient checkpointing to final action projection if enabled + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() + def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = observation.state.shape[0] + if noise is None: + actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) + noise = self.sample_noise(actions_shape, device) + + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Compute image and language key value 
cache + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + + # Euler step - use new tensor assignment instead of in-place operation + x_t = x_t + dt * v_t + time += dt + return x_t + + def denoise_step( + self, + state, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + # Prepare attention masks + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks_4d, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out.to(dtype=torch.float32) + return self.action_out_proj(suffix_out) diff --git a/capvector-pi05/src/openpi/models_pytorch/preprocessing_pytorch.py b/capvector-pi05/src/openpi/models_pytorch/preprocessing_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..cb90f3672a761f4dcd6b2a51ba51e12c2d2ecce1 --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/preprocessing_pytorch.py @@ -0,0 +1,190 @@ +from collections.abc import Sequence +import logging + +import torch + +from openpi.shared import image_tools + +logger = logging.getLogger("openpi") + +# Constants moved from model.py +IMAGE_KEYS = ( + "base_0_rgb", + "left_wrist_0_rgb", + "right_wrist_0_rgb", +) + +IMAGE_RESOLUTION = (224, 224) + + +def preprocess_observation_pytorch( + observation, + *, + train: bool = False, + get_wo_aug: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +): + """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. + + This function avoids complex type annotations that can cause torch.compile issues. 
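+
+    Images may arrive as [B, C, H, W] or [B, H, W, C] with values in [-1, 1];
+    augmentations run in [0, 1] internally and the result is returned in the
+    original range and channel layout.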
+ """ + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + out_images_wo_aug = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key and not get_wo_aug: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) + image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode="bilinear", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Save original images (with color_aug, but without rotation) for VGGT input + img_inv_padding = image.clone() if not (image == 0).all() else torch.ones_like(image) + img_inv_padding[~observation.image_padding_mask[key]] = 1.0 # Set padding areas to white + img_inv_padding = img_inv_padding.permute(0, 3, 1, 2) if is_channels_first else img_inv_padding + 
out_images_wo_aug[key] = img_inv_padding.contiguous() + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + # obtain image padding mask for non-rectangular images + img_padding_mask = {key: observation.image_padding_mask[key] for key in out_images} + + # Create a simple object with the required attributes instead of using the complex Observation class + class SimpleProcessedObservation: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + result_kwargs = { + "images": out_images, + "image_padding_mask": img_padding_mask, + "image_masks": out_masks, + "state": observation.state, + "tokenized_prompt": observation.tokenized_prompt, + "tokenized_prompt_mask": observation.tokenized_prompt_mask, + "token_ar_mask": observation.token_ar_mask, + "token_loss_mask": observation.token_loss_mask, + } + + if get_wo_aug: + result_kwargs["img_wo_aug"] = out_images_wo_aug + + return SimpleProcessedObservation(**result_kwargs) \ No newline at end of file diff --git a/capvector-pi05/src/openpi/models_pytorch/projectors.py b/capvector-pi05/src/openpi/models_pytorch/projectors.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ff14bd50cb0d343f5450dae307aa2f3db871ac --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/projectors.py @@ -0,0 +1,64 @@ +"""Implementation of additional projectors for additional inputs to the VLA models.""" +import torch +import torch.nn as nn +import openpi.models.gemma as _gemma + +class AlignProjector(nn.Module): + """ + calculate the alignment between LLM and VGGT embeddings. 
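+
+    LLM tokens are projected through a two-layer GELU MLP to a width of
+    2 * vggt_dim (the width of the VGGT tokens they are compared against) and
+    scored with a masked cosine-similarity loss.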
+ """ + def __init__( + self, + llm_dim: int, + vggt_dim: int, + use_vlm_norm: bool = False, + ) -> None: + super().__init__() + + self.llm_dim = llm_dim + self.vggt_dim = vggt_dim + + self.fc1 = nn.Linear(self.llm_dim, 2 * self.vggt_dim, bias=True) + self.fc2 = nn.Linear(2 * self.vggt_dim, 2 * self.vggt_dim, bias=True) + self.act_fn1 = nn.GELU() + + self.vlm_norm = nn.LayerNorm(llm_dim) if use_vlm_norm else None + + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def align_dimension(self, LLM_embedding: torch.Tensor = None) -> torch.Tensor: + if self.vlm_norm is not None: + LLM_embedding = self.vlm_norm(LLM_embedding) + projected_features = self.fc1(LLM_embedding) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + return projected_features + + def compute_align_loss_cosine(self, vision_hidden, vggt_hidden, align_mask): + # vision_hidden has a shape of (bs, N, D) + def mean_flat(x): + return torch.mean(x, dim=list(range(1, len(x.size())))) + align_loss = 0 + bsz = vision_hidden.shape[0] + for _vision, _vggt, _mask in zip(vision_hidden, vggt_hidden, align_mask): + _vision = torch.nn.functional.normalize(_vision, dim=-1) + _vggt = torch.nn.functional.normalize(_vggt, dim=-1) + # align_loss += 1 - torch.mean(vision_hidden * vggt_hidden).sum(dim=-1).mean() + align_loss += 1 - mean_flat((_vision * _vggt)[_mask].sum(dim=-1)) # Cosine similarity loss + align_loss /= bsz # Average over batch size + return align_loss + + def forward(self, LLM_emb, target_emb, align_mask): + # project vla dimension and calculate align loss + LLM_emb = self.align_dimension(LLM_emb) + align_loss = self.compute_align_loss_cosine(LLM_emb, target_emb, align_mask).mean() # mean for sequence length + return align_loss diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..472dd16f9377a6868e5f7659a76519b37483d31a --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py @@ -0,0 +1,173 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional +from ...configuration_utils import PretrainedConfig + + +class GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The legacy activation function. It is overwritten by the `hidden_activation`. + hidden_activation (`str` or `function`, *optional*): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + use_adarms (`bool`, *optional*, defaults to `False`): + Whether to use ADARMS. + adarms_cond_dim (`int`, *optional*, defaults to `None`): + The dimension of the ADARMS condition. + ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + use_adarms: bool = False, + adarms_cond_dim: Optional[int] = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_adarms = use_adarms + self.adarms_cond_dim = adarms_cond_dim + + # Set default for adarms_cond_dim if use_adarms is True + if self.use_adarms and self.adarms_cond_dim is None: + self.adarms_cond_dim = self.hidden_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["GemmaConfig"] \ No newline at end of file diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py new file mode 100644 index 
0000000000000000000000000000000000000000..8377a5bf8562945fe0f3b2c3545a91c7d7ac9238 --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py @@ -0,0 +1,862 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from .configuration_gemma import GemmaConfig + + +logger = logging.get_logger(__name__) + + +class GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None): + super().__init__() + self.eps = eps + self.dim = dim + self.cond_dim = cond_dim + + # Dense layer for adaptive normalization (if cond_dim is provided) + if cond_dim is not None: + #self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16) + self.dense = nn.Linear(cond_dim, dim * 3, bias=True) + # Initialize with zeros (matches source implementation) + nn.init.zeros_(self.dense.weight) + else: + self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16)) + self.dense = None + + def _norm(self, x): + # Compute variance in float32 (like the source implementation) + var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) + # Compute normalization in float32 + normed_inputs = x * torch.rsqrt(var + self.eps) + return normed_inputs + + def forward(self, x, cond=None): + dtype = x.dtype # original dtype, could be half-precision + normed_inputs = self._norm(x) + + if cond is None or self.dense is None: + # regular RMSNorm + # scale by learned parameter in float32 (matches source implementation) + normed_inputs = normed_inputs * (1.0 + self.weight.float()) + return normed_inputs.to(dtype), None # return in original dtype with None gate + + # adaptive RMSNorm (if cond is provided and dense layer 
exists)
+        if cond.shape[-1] != self.cond_dim:
+            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
+
+        modulation = self.dense(cond)
+        # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
+        if len(x.shape) == 3:  # [batch, seq, features]
+            modulation = modulation.unsqueeze(1)
+
+        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
+
+        # Apply the adaptive modulation in float32, then cast back to the input dtype.
+        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
+
+        return normed_inputs.to(dtype), gate.to(dtype)
+
+    def extra_repr(self):
+        # Note: self.weight only exists in the non-adaptive configuration.
+        if self.dense is not None:
+            return f"({self.dim},), eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class GemmaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
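Since the adaptive branch of `GemmaRMSNorm` is the main departure from stock Gemma here, the following self-contained sketch (illustrative sizes only, not part of the patch) spells out the scale/shift/gate math, including how the returned gate is later consumed by the `_gated_residual` helper defined below:

```python
import torch

dim, cond_dim, bsz, seq = 8, 16, 2, 4
dense = torch.nn.Linear(cond_dim, dim * 3)
torch.nn.init.zeros_(dense.weight)  # mirrors the zero weight init above

x = torch.randn(bsz, seq, dim)
cond = torch.randn(bsz, cond_dim)  # e.g. a timestep embedding

# RMS-normalize in float32, as GemmaRMSNorm._norm does.
var = x.float().square().mean(dim=-1, keepdim=True)
normed = x * torch.rsqrt(var + 1e-6)

# Project the condition to (scale, shift, gate) and broadcast over the sequence dim.
scale, shift, gate = dense(cond).unsqueeze(1).chunk(3, dim=-1)

out = normed * (1 + scale) + shift  # what the adaptive forward returns
new_residual = x + out * gate       # how _gated_residual applies the gate
```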
+class GemmaRotaryEmbedding(nn.Module):
+    def __init__(self, config: GemmaConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def _gated_residual(x, y, gate):
+    """
+    Applies gated residual connection with optional gate parameter.
+ + Args: + x: Input tensor (residual) + y: Output tensor to be added + gate: Optional gate tensor to modulate the addition + + Returns: + x + y if gate is None, otherwise x + y * gate + """ + if x is None and y is None: + return None + if x is None or y is None: + return x if x is not None else y + if gate is None: + return x + y + return x + y * gate + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class GemmaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # Use cache if provided + if past_key_value is not None: + if use_cache: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + else: + key_states = 
torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2) + value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class GemmaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + + self.mlp = GemmaMLP(config) + cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + adarms_cond: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = _gated_residual(residual, hidden_states, gate) + + # Fully Connected + residual = hidden_states + hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond) + hidden_states = self.mlp(hidden_states) + hidden_states = _gated_residual(residual, hidden_states, gate) + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class GemmaPreTrainedModel(PreTrainedModel): + config_class = GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif 
isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GemmaRMSNorm): + if hasattr(module, 'weight'): + module.weight.data.fill_(1.0) + + +@auto_docstring +class GemmaModel(GemmaPreTrainedModel): + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) + self.rotary_emb = GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + adarms_cond: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + """ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + # embed positions + hidden_states = inputs_embeds + # Convert to bfloat16 if the first layer uses bfloat16 + if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + #hidden_states = hidden_states * normalizer + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + adarms_cond=adarms_cond, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states, _ = self.norm(hidden_states, adarms_cond) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
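Before the task heads below, a minimal end-to-end sketch of driving the AdaRMS-conditioned decoder stack. It assumes the `transformers_replace` tree has been copied over an installed transformers 4.53.2 (the version pinned by `check.py` later in this diff) so that `GemmaConfig`/`GemmaModel` resolve to the files above; all sizes are illustrative:

```python
import torch
from transformers import GemmaConfig, GemmaModel  # resolves to the replaced files

config = GemmaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    head_dim=8,
    use_adarms=True,      # enable the adaptive RMSNorm layers
    adarms_cond_dim=32,   # width of the conditioning vector
)
model = GemmaModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 5))
cond = torch.randn(1, config.adarms_cond_dim)  # e.g. a flow-matching timestep embedding

out = model(input_ids=input_ids, adarms_cond=cond)
print(out.last_hidden_state.shape)  # torch.Size([1, 5, 32])
```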
+ + +@auto_docstring +class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = GemmaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + adarms_cond: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + adarms_cond=adarms_cond, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Gemma Model transformer with a sequence classification head on top (linear layer). + + [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class GemmaForSequenceClassification(GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = GemmaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + adarms_cond: Optional[torch.Tensor] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + adarms_cond=adarms_cond, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class GemmaForTokenClassification(GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = GemmaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + adarms_cond: Optional[torch.Tensor] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+            Condition for ADARMS.
+        """
+
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            adarms_cond=adarms_cond,
+        )
+        sequence_output = outputs.last_hidden_state
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config)
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "GemmaModel",
+    "GemmaForCausalLM",
+    "GemmaForSequenceClassification",
+    "GemmaForTokenClassification",
+    "GemmaPreTrainedModel",
+]
diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..a627b73246277095e3354b93158c98e1fa776897
--- /dev/null
+++ b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
@@ -0,0 +1,622 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PaliGemma model."""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...cache_utils import Cache, HybridCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ..auto import AutoModel
+from .configuration_paligemma import PaliGemmaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for Paligemma outputs, with hidden states and attentions.
+ """ +) +class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for PaliGemma causal language model (or autoregressive) outputs. + """ +) +class PaliGemmaCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class PaliGemmaMultiModalProjector(nn.Module): + def __init__(self, config: PaliGemmaConfig): + super().__init__() + self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True) + + def forward(self, image_features): + hidden_states = self.linear(image_features) + + return hidden_states + + +@auto_docstring +class PaliGemmaPreTrainedModel(PreTrainedModel): + config_class = PaliGemmaConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["PaliGemmaMultiModalProjector"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of PaliGemmaisn't meant for training from scratch - only + # inference and fine-tuning + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + The Base Paligemma model which consists of a vision backbone and a language model withou language modeling head., + """ +) +class PaliGemmaModel(PaliGemmaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch + accepts_loss_kwargs = False + + def __init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModel.from_config(config=config.text_config) + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def _update_causal_mask( + self, + attention_mask, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + is_training: Optional[bool] = None, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + is_training = is_training if is_training is not None else self.training + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = 
torch.finfo(self.dtype).min + if input_tensor is None: + input_tensor = attention_mask + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device + ) + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + if token_type_ids is None: + raise ValueError("Token type ids must be provided during training") + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, PaligemmaModelOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." 
+                )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+        )
+        outputs = self.language_model(
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return PaligemmaModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_features if pixel_values is not None else None,
+        )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring(
+    custom_intro="""
+    The PaliGemma model for conditional generation, consisting of a vision backbone and a language model with a
+    language modeling head on top.
+    """
+)
+class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
+    _checkpoint_conversion_mapping = {
+        "^language_model.model": "model.language_model",
+        "^vision_tower": "model.vision_tower",
+        "^multi_modal_projector": "model.multi_modal_projector",
+        "^language_model.lm_head": "lm_head",
+    }
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config: PaliGemmaConfig):
+        super().__init__(config)
+        self.model = PaliGemmaModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def get_image_features(self, pixel_values):
+        return self.model.get_image_features(pixel_values)
+
+    # Make modules available through the conditional class for BC
+    @property
+    def language_model(self):
+        return self.model.language_model
+
+    @property
+    def vision_tower(self):
+        return self.model.vision_tower
+
+    @property
+    def multi_modal_projector(self):
+        return self.model.multi_modal_projector
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[KwargsForCausalLM],
+    ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked
language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return PaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + 
+        # Otherwise we need pixel values to be passed to the model. NOTE: use_cache=False always needs pixel_values.
+        if cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+        is_training = token_type_ids is not None and labels is not None
+        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+            causal_mask = self.model._update_causal_mask(
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+            )
+            model_inputs["attention_mask"] = causal_mask
+
+        return model_inputs
+
+    @staticmethod
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static
+                cache, to account for the 0 padding, i.e. the part of the cache that is not yet filled.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
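+            # ("Inverted form" = the additive-mask convention: 0.0 where attention is allowed and a
+            # large negative value, torch.finfo(dtype).min, where it is blocked, so the mask can be
+            # added directly to the attention logits. Illustration: with sequence_length=3,
+            # target_length=5 and cache_position=[0, 1, 2], the else branch below builds
+            # (m = the large negative value):
+            #   [[0, m, m, m, m],
+            #    [0, 0, m, m, m],
+            #    [0, 0, 0, m, m]]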
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py new file mode 100644 index 0000000000000000000000000000000000000000..89cc2ad4359d5273f3631410cbebbe845100cce6 --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py @@ -0,0 +1,4 @@ +import transformers + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" \ No newline at end of file diff --git a/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf8bec4a068c964dd038ce9060513f61a0b8ff4 --- /dev/null +++ b/capvector-pi05/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py @@ -0,0 +1,1237 @@ +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
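+#
+# NOTE: this file is part of openpi's `transformers_replace` overlay and is intended to shadow the
+# stock `transformers` SigLIP implementation; `check.py` in this directory pins the expected
+# `transformers` version to 4.53.2.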
+"""PyTorch Siglip model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int +from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + + +logger = logging.get_logger(__name__) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. 
+    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+        similarity scores.
+    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+        The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
+    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+        The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
+    text_model_output (`BaseModelOutputWithPooling`):
+        The output of the [`SiglipTextModel`].
+    vision_model_output (`BaseModelOutputWithPooling`):
+        The output of the [`SiglipVisionModel`].
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits_per_image: Optional[torch.FloatTensor] = None
+    logits_per_text: Optional[torch.FloatTensor] = None
+    text_embeds: Optional[torch.FloatTensor] = None
+    image_embeds: Optional[torch.FloatTensor] = None
+    text_model_output: BaseModelOutputWithPooling = None
+    vision_model_output: BaseModelOutputWithPooling = None
+
+    def to_tuple(self) -> tuple[Any]:
+        return tuple(
+            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+            for k in self.keys()
+        )
+
+
+class SiglipVisionEmbeddings(nn.Module):
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on
+        higher-resolution images. It is also adapted to support torch.jit tracing and no class embeddings.
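+
+        For example (illustrative numbers): a checkpoint pretrained at 224x224 with patch_size=16
+        stores a 14x14 grid of position embeddings; for a 384x384 input, the grid is bicubically
+        resized to 24x24.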
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights 
= torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = SiglipAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
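+
+        (The body below follows the standard pre-LayerNorm transformer layout:
+        x = x + attn(ln1(x)), then x = x + mlp(ln2(x)).)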
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +@auto_docstring +class SiglipPreTrainedModel(PreTrainedModel): + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipEncoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, SiglipForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. 
+ + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @can_return_tuple + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, config.projection_size) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( 
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The text model from SigLIP without any head or projection on top. + """ +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else 
config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + ) -> BaseModelOutputWithPooling: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + # Convert to bfloat16 if the encoder uses bfloat16 + if len(self.encoder.layers) > 0 and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@auto_docstring( + custom_intro=""" + The vision model from SigLIP without any head or projection on top. 
+ """ +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> BaseModelOutputWithPooling: + r""" + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + +@auto_docstring +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise TypeError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. 
+ + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = text_outputs.pooler_output + + return pooled_output + + @auto_docstring + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs.pooler_output + + return pooled_output + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> SiglipOutput: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. 
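+
+        When `return_loss=True`, the pairwise sigmoid loss from the SigLIP paper is computed: every
+        image-text pair in the batch is scored independently, with matching (diagonal) pairs labeled
+        +1 and all other pairings labeled -1 (see the `m1_diag1` construction in the body below).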
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@auto_docstring( + custom_intro=""" + SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. 
+ """ +) +class SiglipForImageClassification(SiglipPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: SiglipConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = SiglipVisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SiglipForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `SiglipModel` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + sequence_output = torch.mean(sequence_output, dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", + "SiglipForImageClassification", +] \ No newline at end of file diff --git a/capvector-pi05/src/openpi/policies/aloha_policy.py b/capvector-pi05/src/openpi/policies/aloha_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..b006f736096b6d8301262be477ac6b2329707337 --- /dev/null +++ b/capvector-pi05/src/openpi/policies/aloha_policy.py @@ -0,0 +1,202 @@ +import dataclasses +from typing import ClassVar + +import einops +import numpy as np + +from openpi import transforms + + +def make_aloha_example() -> dict: + """Creates a random input example for the Aloha policy.""" + return { + "state": np.ones((14,)), + "images": { + "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), + }, + "prompt": "do something", + } + + +@dataclasses.dataclass(frozen=True) +class AlohaInputs(transforms.DataTransformFn): + """Inputs for the Aloha policy. 
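+
+    Images may arrive as float or uint8 arrays in [channel, height, width] layout; `_decode_aloha`
+    below converts them to uint8 [height, width, channel].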
+ + Expected inputs: + - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS. + - state: [14] + - actions: [action_horizon, 14] + """ + + # If true, this will convert the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi: bool = True + + # The expected cameras names. All input cameras must be in this set. Missing cameras will be + # replaced with black images and the corresponding `image_mask` will be set to False. + EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist") + + def __call__(self, data: dict) -> dict: + data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi) + + in_images = data["images"] + if set(in_images) - set(self.EXPECTED_CAMERAS): + raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}") + + # Assume that base image always exists. + base_image = in_images["cam_high"] + + images = { + "base_0_rgb": base_image, + } + image_masks = { + "base_0_rgb": np.True_, + } + + # Add the extra images. + extra_image_names = { + "left_wrist_0_rgb": "cam_left_wrist", + "right_wrist_0_rgb": "cam_right_wrist", + } + for dest, source in extra_image_names.items(): + if source in in_images: + images[dest] = in_images[source] + image_masks[dest] = np.True_ + else: + images[dest] = np.zeros_like(base_image) + image_masks[dest] = np.False_ + + inputs = { + "image": images, + "image_mask": image_masks, + "state": data["state"], + } + + # Actions are only available during training. + if "actions" in data: + actions = np.asarray(data["actions"]) + actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi) + inputs["actions"] = actions + + if "prompt" in data: + inputs["prompt"] = data["prompt"] + + return inputs + + +@dataclasses.dataclass(frozen=True) +class AlohaOutputs(transforms.DataTransformFn): + """Outputs for the Aloha policy.""" + + # If true, this will convert the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi: bool = True + + def __call__(self, data: dict) -> dict: + # Only return the first 14 dims. + actions = np.asarray(data["actions"][:, :14]) + return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)} + + +def _joint_flip_mask() -> np.ndarray: + """Used to convert between aloha and pi joint angles.""" + return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1]) + + +def _normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def _unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def _gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with pi0 which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = _unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return np.arcsin(np.clip(value, -1.0, 1.0)) + + # The constants are taken from the Interbotix code. 
+ value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110). + # There are 4096 total encoder counts and aloha uses a zero of 2048. + # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296) + return _normalize(value, min_val=0.5476, max_val=1.6296) + + +def _gripper_from_angular(value): + # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # We do not scale the output since the trossen model predictions are already in radians. + # See the comment in _gripper_to_angular for a derivation of the constant + value = value + 0.5476 + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return _normalize(value, min_val=-0.6213, max_val=1.4910) + + +def _gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = _unnormalize(value, min_val=-0.6213, max_val=1.4910) + return value - 0.5476 + + +def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict: + # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper] + # dim sizes: [6, 1, 6, 1] + state = np.asarray(data["state"]) + state = _decode_state(state, adapt_to_pi=adapt_to_pi) + + def convert_image(img): + img = np.asarray(img) + # Convert to uint8 if using float images. + if np.issubdtype(img.dtype, np.floating): + img = (255 * img).astype(np.uint8) + # Convert from [channel, height, width] to [height, width, channel]. + return einops.rearrange(img, "c h w -> h w c") + + images = data["images"] + images_dict = {name: convert_image(img) for name, img in images.items()} + + data["images"] = images_dict + data["state"] = state + return data + + +def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray: + if adapt_to_pi: + # Flip the joints. + state = _joint_flip_mask() * state + # Reverse the gripper transformation that is being applied by the Aloha runtime. + state[[6, 13]] = _gripper_to_angular(state[[6, 13]]) + return state + + +def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray: + if adapt_to_pi: + # Flip the joints. 
+ actions = _joint_flip_mask() * actions + actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]]) + return actions + + +def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray: + if adapt_to_pi: + actions = _joint_flip_mask() * actions + actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]]) + return actions diff --git a/capvector-pi05/src/openpi/policies/droid_policy.py b/capvector-pi05/src/openpi/policies/droid_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..786985d9268f1aed0a57bf2bb780ca6da692683f --- /dev/null +++ b/capvector-pi05/src/openpi/policies/droid_policy.py @@ -0,0 +1,81 @@ +import dataclasses + +import einops +import numpy as np + +from openpi import transforms +from openpi.models import model as _model + + +def make_droid_example() -> dict: + """Creates a random input example for the Droid policy.""" + return { + "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "observation/joint_position": np.random.rand(7), + "observation/gripper_position": np.random.rand(1), + "prompt": "do something", + } + + +def _parse_image(image) -> np.ndarray: + image = np.asarray(image) + if np.issubdtype(image.dtype, np.floating): + image = (255 * image).astype(np.uint8) + if image.shape[0] == 3: + image = einops.rearrange(image, "c h w -> h w c") + return image + + +@dataclasses.dataclass(frozen=True) +class DroidInputs(transforms.DataTransformFn): + # Determines which model will be used. + model_type: _model.ModelType + + def __call__(self, data: dict) -> dict: + gripper_pos = np.asarray(data["observation/gripper_position"]) + if gripper_pos.ndim == 0: + # Ensure gripper position is a 1D array, not a scalar, so we can concatenate with joint positions + gripper_pos = gripper_pos[np.newaxis] + state = np.concatenate([data["observation/joint_position"], gripper_pos]) + + # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically + # stores as float32 (C,H,W), gets skipped for policy inference + base_image = _parse_image(data["observation/exterior_image_1_left"]) + wrist_image = _parse_image(data["observation/wrist_image_left"]) + + match self.model_type: + case _model.ModelType.PI0 | _model.ModelType.PI05: + names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") + images = (base_image, wrist_image, np.zeros_like(base_image)) + image_masks = (np.True_, np.True_, np.False_) + case _model.ModelType.PI0_FAST: + names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb") + # We don't mask out padding images for FAST models. + images = (base_image, np.zeros_like(base_image), wrist_image) + image_masks = (np.True_, np.True_, np.True_) + case _: + raise ValueError(f"Unsupported model type: {self.model_type}") + + inputs = { + "state": state, + "image": dict(zip(names, images, strict=True)), + "image_mask": dict(zip(names, image_masks, strict=True)), + } + + if "actions" in data: + inputs["actions"] = np.asarray(data["actions"]) + + if "prompt" in data: + if isinstance(data["prompt"], bytes): + data["prompt"] = data["prompt"].decode("utf-8") + inputs["prompt"] = data["prompt"] + + return inputs + + +@dataclasses.dataclass(frozen=True) +class DroidOutputs(transforms.DataTransformFn): + def __call__(self, data: dict) -> dict: + # Only return the first 8 dims. 
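+        # (8 = the 7 joint dimensions + 1 gripper dimension that DroidInputs packs into the state above.)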
+ return {"actions": np.asarray(data["actions"][:, :8])} diff --git a/capvector-pi05/src/openpi/policies/libero_policy.py b/capvector-pi05/src/openpi/policies/libero_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..7b51e93d201e73e155aba27db9c0fe531d93d074 --- /dev/null +++ b/capvector-pi05/src/openpi/policies/libero_policy.py @@ -0,0 +1,100 @@ +import dataclasses + +import einops +import numpy as np + +from openpi import transforms +from openpi.models import model as _model + + +def make_libero_example() -> dict: + """Creates a random input example for the Libero policy.""" + return { + "observation/state": np.random.rand(8), + "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), + "prompt": "do something", + } + + +def _parse_image(image) -> np.ndarray: + image = np.asarray(image) + if np.issubdtype(image.dtype, np.floating): + image = (255 * image).astype(np.uint8) + if image.shape[0] == 3: + image = einops.rearrange(image, "c h w -> h w c") + return image + + +@dataclasses.dataclass(frozen=True) +class LiberoInputs(transforms.DataTransformFn): + """ + This class is used to convert inputs to the model to the expected format. It is used for both training and inference. + + For your own dataset, you can copy this class and modify the keys based on the comments below to pipe + the correct elements of your dataset into the model. + """ + + # Determines which model will be used. + # Do not change this for your own dataset. + model_type: _model.ModelType + + def __call__(self, data: dict) -> dict: + # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically + # stores as float32 (C,H,W), gets skipped for policy inference. + # Keep this for your own dataset, but if your dataset stores the images + # in a different key than "observation/image" or "observation/wrist_image", + # you should change it below. + # Pi0 models support three image inputs at the moment: one third-person view, + # and two wrist views (left and right). If your dataset does not have a particular type + # of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the + # right wrist image below. + base_image = _parse_image(data["observation/image"]) + wrist_image = _parse_image(data["observation/wrist_image"]) + + # Create inputs dict. Do not change the keys in the dict below. + inputs = { + "state": data["observation/state"], + "image": { + "base_0_rgb": base_image, + "left_wrist_0_rgb": wrist_image, + # Pad any non-existent images with zero-arrays of the appropriate shape. + "right_wrist_0_rgb": np.zeros_like(base_image), + }, + "image_mask": { + "base_0_rgb": np.True_, + "left_wrist_0_rgb": np.True_, + # We only mask padding images for pi0 model, not pi0-FAST. Do not change this for your own dataset. + "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_, + }, + } + + # Pad actions to the model action dimension. Keep this for your own dataset. + # Actions are only available during training. + if "actions" in data: + inputs["actions"] = data["actions"] + + # Pass the prompt (aka language instruction) to the model. + # Keep this for your own dataset (but modify the key if the instruction is not + # stored in "prompt"; the output dict always needs to have the key "prompt"). 
+ if "prompt" in data: + inputs["prompt"] = data["prompt"] + + return inputs + + +@dataclasses.dataclass(frozen=True) +class LiberoOutputs(transforms.DataTransformFn): + """ + This class is used to convert outputs from the model back the the dataset specific format. It is + used for inference only. + + For your own dataset, you can copy this class and modify the action dimension based on the comments below. + """ + + def __call__(self, data: dict) -> dict: + # Only return the first N actions -- since we padded actions above to fit the model action + # dimension, we need to now parse out the correct number of actions in the return dict. + # For Libero, we only return the first 7 actions (since the rest is padding). + # For your own dataset, replace `7` with the action dimension of your dataset. + return {"actions": np.asarray(data["actions"][:, :7])} diff --git a/capvector-pi05/src/openpi/policies/policy.py b/capvector-pi05/src/openpi/policies/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..334f50d68038e02780d0a937cfdae4a9826c0b7e --- /dev/null +++ b/capvector-pi05/src/openpi/policies/policy.py @@ -0,0 +1,135 @@ +from collections.abc import Sequence +import logging +import pathlib +import time +from typing import Any, TypeAlias + +import flax +import flax.traverse_util +import jax +import jax.numpy as jnp +import numpy as np +from openpi_client import base_policy as _base_policy +import torch +from typing_extensions import override + +from openpi import transforms as _transforms +from openpi.models import model as _model +from openpi.shared import array_typing as at +from openpi.shared import nnx_utils + +BasePolicy: TypeAlias = _base_policy.BasePolicy + + +class Policy(BasePolicy): + def __init__( + self, + model: _model.BaseModel, + *, + rng: at.KeyArrayLike | None = None, + transforms: Sequence[_transforms.DataTransformFn] = (), + output_transforms: Sequence[_transforms.DataTransformFn] = (), + sample_kwargs: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + pytorch_device: str = "cpu", + is_pytorch: bool = False, + ): + """Initialize the Policy. + + Args: + model: The model to use for action sampling. + rng: Random number generator key for JAX models. Ignored for PyTorch models. + transforms: Input data transformations to apply before inference. + output_transforms: Output data transformations to apply after inference. + sample_kwargs: Additional keyword arguments to pass to model.sample_actions. + metadata: Additional metadata to store with the policy. + pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0"). + Only relevant when is_pytorch=True. + is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model. + """ + self._model = model + self._input_transform = _transforms.compose(transforms) + self._output_transform = _transforms.compose(output_transforms) + self._sample_kwargs = sample_kwargs or {} + self._metadata = metadata or {} + self._is_pytorch_model = is_pytorch + self._pytorch_device = pytorch_device + + if self._is_pytorch_model: + self._model = self._model.to(pytorch_device) + self._model.eval() + self._sample_actions = model.sample_actions + else: + # JAX model setup + self._sample_actions = nnx_utils.module_jit(model.sample_actions) + self._rng = rng or jax.random.key(0) + + @override + def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: ignore[misc] + # Make a copy since transformations may modify the inputs in place. 
+ inputs = jax.tree.map(lambda x: x, obs) + inputs = self._input_transform(inputs) + if not self._is_pytorch_model: + # Make a batch and convert to jax.Array. + inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs) + self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng) + else: + # Convert inputs to PyTorch tensors and move to correct device + inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs) + sample_rng_or_pytorch_device = self._pytorch_device + + # Prepare kwargs for sample_actions + sample_kwargs = dict(self._sample_kwargs) + if noise is not None: + noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise) + + if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension + noise = noise[None, ...] # Make it (1, action_horizon, action_dim) + sample_kwargs["noise"] = noise + + observation = _model.Observation.from_dict(inputs) + start_time = time.monotonic() + outputs = { + "state": inputs["state"], + "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs), + } + model_time = time.monotonic() - start_time + if self._is_pytorch_model: + outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs) + else: + outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs) + + outputs = self._output_transform(outputs) + outputs["policy_timing"] = { + "infer_ms": model_time * 1000, + } + return outputs + + @property + def metadata(self) -> dict[str, Any]: + return self._metadata + + +class PolicyRecorder(_base_policy.BasePolicy): + """Records the policy's behavior to disk.""" + + def __init__(self, policy: _base_policy.BasePolicy, record_dir: str): + self._policy = policy + + logging.info(f"Dumping policy records to: {record_dir}") + self._record_dir = pathlib.Path(record_dir) + self._record_dir.mkdir(parents=True, exist_ok=True) + self._record_step = 0 + + @override + def infer(self, obs: dict) -> dict: # type: ignore[misc] + results = self._policy.infer(obs) + + data = {"inputs": obs, "outputs": results} + data = flax.traverse_util.flatten_dict(data, sep="/") + + output_path = self._record_dir / f"step_{self._record_step}" + self._record_step += 1 + + np.save(output_path, np.asarray(data)) + return results diff --git a/capvector-pi05/src/openpi/policies/policy_config.py b/capvector-pi05/src/openpi/policies/policy_config.py new file mode 100644 index 0000000000000000000000000000000000000000..18bc2211348f24fc29df58872115aaf826636e1c --- /dev/null +++ b/capvector-pi05/src/openpi/policies/policy_config.py @@ -0,0 +1,94 @@ +import logging +import os +import pathlib +from typing import Any + +import jax.numpy as jnp + +import openpi.models.model as _model +import openpi.policies.policy as _policy +import openpi.shared.download as download +from openpi.training import checkpoints as _checkpoints +from openpi.training import config as _config +import openpi.transforms as transforms + + +def create_trained_policy( + train_config: _config.TrainConfig, + checkpoint_dir: pathlib.Path | str, + *, + repack_transforms: transforms.Group | None = None, + sample_kwargs: dict[str, Any] | None = None, + default_prompt: str | None = None, + norm_stats: dict[str, transforms.NormStats] | None = None, + pytorch_device: str | None = None, +) -> _policy.Policy: + """Create a policy from a trained checkpoint. + + Args: + train_config: The training config to use to create the model. 
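+            This should be the same config that the checkpoint was trained with.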
+        checkpoint_dir: The directory to load the model from.
+        repack_transforms: Optional transforms that will be applied before any other transforms.
+        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
+            kwargs will be used.
+        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
+            data if it doesn't already exist.
+        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
+            from the checkpoint directory.
+        pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
+            If None and the checkpoint contains a PyTorch model, "cuda" will be used when
+            available, otherwise "cpu".
+
+    Note:
+        The function automatically detects whether the model is PyTorch-based by checking for the
+        presence of "model.safetensors" in the checkpoint directory.
+    """
+    repack_transforms = repack_transforms or transforms.Group()
+    checkpoint_dir = download.maybe_download(str(checkpoint_dir))
+
+    # Check if this is a PyTorch model by looking for model.safetensors.
+    weight_path = os.path.join(checkpoint_dir, "model.safetensors")
+    is_pytorch = os.path.exists(weight_path)
+
+    logging.info("Loading model...")
+    if is_pytorch:
+        model = train_config.model.load_pytorch(train_config, weight_path)
+        model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
+    else:
+        model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
+    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
+    if norm_stats is None:
+        # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
+        # that the policy is using the same normalization stats as the original training process.
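+        # (The asset_id resolved below selects which subdirectory of the checkpoint's
+        # assets directory the stats are read from.)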
+        if data_config.asset_id is None:
+            raise ValueError("Asset id is required to load norm stats.")
+        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
+
+    # Determine the device to use for PyTorch models.
+    if is_pytorch and pytorch_device is None:
+        try:
+            import torch
+
+            pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        except ImportError:
+            pytorch_device = "cpu"
+
+    return _policy.Policy(
+        model,
+        transforms=[
+            *repack_transforms.inputs,
+            transforms.InjectDefaultPrompt(default_prompt),
+            *data_config.data_transforms.inputs,
+            transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
+            *data_config.model_transforms.inputs,
+        ],
+        output_transforms=[
+            *data_config.model_transforms.outputs,
+            transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
+            *data_config.data_transforms.outputs,
+            *repack_transforms.outputs,
+        ],
+        sample_kwargs=sample_kwargs,
+        metadata=train_config.policy_metadata,
+        is_pytorch=is_pytorch,
+        pytorch_device=pytorch_device if is_pytorch else None,
+    )
diff --git a/capvector-pi05/src/openpi/py.typed b/capvector-pi05/src/openpi/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/capvector-pi05/src/openpi/shared/array_typing.py b/capvector-pi05/src/openpi/shared/array_typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed20bfa1c79d540815bf2c444e803ba34383c9f
--- /dev/null
+++ b/capvector-pi05/src/openpi/shared/array_typing.py
@@ -0,0 +1,89 @@
+import contextlib
+import functools as ft
+import inspect
+from typing import TypeAlias, TypeVar, cast
+
+import beartype
+import jax
+import jax._src.tree_util as private_tree_util
+import jax.core
+from jaxtyping import ArrayLike
+from jaxtyping import Bool  # noqa: F401
+from jaxtyping import DTypeLike  # noqa: F401
+from jaxtyping import Float
+from jaxtyping import Int  # noqa: F401
+from jaxtyping import Key  # noqa: F401
+from jaxtyping import Num  # noqa: F401
+from jaxtyping import PyTree
+from jaxtyping import Real  # noqa: F401
+from jaxtyping import UInt8  # noqa: F401
+from jaxtyping import config
+from jaxtyping import jaxtyped
+import jaxtyping._decorator
+import torch
+
+# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.
+# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,
+# `jax.Sharding`, or even plain placeholder objects) due to JAX tracing operations. this patch skips typechecking when
+# the stack trace contains `jax._src.tree_util`, which should only be the case during tree unflattening.
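+# The patch below keeps a reference to the original hook, then installs a wrapper that
+# walks the call stack and only falls back to the original check outside of the JAX/flax internals.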
+_original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations # noqa: SLF001 +# Redefine Array to include both JAX arrays and PyTorch tensors +Array = jax.Array | torch.Tensor + + +def _check_dataclass_annotations(self, typechecker): + if not any( + frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"} + for frame in inspect.stack() + ): + return _original_check_dataclass_annotations(self, typechecker) + return None + + +jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations # noqa: SLF001 + +KeyArrayLike: TypeAlias = jax.typing.ArrayLike +Params: TypeAlias = PyTree[Float[ArrayLike, "..."]] + +T = TypeVar("T") + + +# runtime type-checking decorator +def typecheck(t: T) -> T: + return cast(T, ft.partial(jaxtyped, typechecker=beartype.beartype)(t)) + + +@contextlib.contextmanager +def disable_typechecking(): + initial = config.jaxtyping_disable + config.update("jaxtyping_disable", True) # noqa: FBT003 + yield + config.update("jaxtyping_disable", initial) + + +def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False): + """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer + error message than if `jax.tree.map` is naively used on PyTrees with different structures. + """ + + if errors := list(private_tree_util.equality_errors(expected, got)): + raise ValueError( + "PyTrees have different structure:\n" + + ( + "\n".join( + f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n" + for path, thing1, thing2, explanation in errors + ) + ) + ) + + if check_shapes or check_dtypes: + + def check(kp, x, y): + if check_shapes and x.shape != y.shape: + raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}") + + if check_dtypes and x.dtype != y.dtype: + raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}") + + jax.tree_util.tree_map_with_path(check, expected, got) diff --git a/capvector-pi05/src/openpi/transforms.py b/capvector-pi05/src/openpi/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..782cf4cab861f39d4535563fb1f5b8aaf1198b49 --- /dev/null +++ b/capvector-pi05/src/openpi/transforms.py @@ -0,0 +1,469 @@ +from collections.abc import Callable, Mapping, Sequence +import dataclasses +import re +from typing import Protocol, TypeAlias, TypeVar, runtime_checkable + +import flax.traverse_util as traverse_util +import jax +import numpy as np +from openpi_client import image_tools + +from openpi.models import tokenizer as _tokenizer +from openpi.shared import array_typing as at +from openpi.shared import normalize as _normalize + +DataDict: TypeAlias = at.PyTree +NormStats: TypeAlias = _normalize.NormStats + + +T = TypeVar("T") +S = TypeVar("S") + + +@runtime_checkable +class DataTransformFn(Protocol): + def __call__(self, data: DataDict) -> DataDict: + """Apply transformation to the data. + + Args: + data: The data to apply the transform to. This is a possibly nested dictionary that contains + unbatched data elements. Each leaf is expected to be a numpy array. Using JAX arrays is allowed + but not recommended since it may result in extra GPU memory usage inside data loader worker + processes. + + Returns: + The transformed data. Could be the input `data` that was modified in place, or a new data structure. 
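+
+        Note: individual transforms are rarely applied alone; they are typically chained
+        with `compose` (see `CompositeTransform` below) and applied in sequence.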
+ """ + + +@dataclasses.dataclass(frozen=True) +class Group: + """A group of transforms.""" + + # Transforms that are applied to the model input data. + inputs: Sequence[DataTransformFn] = () + + # Transforms that are applied to the model output data. + outputs: Sequence[DataTransformFn] = () + + def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Sequence[DataTransformFn] = ()) -> "Group": + """Append transforms to the group and return a new group. + + Args: + inputs: Appended to the *end* of the current input transforms. + outputs: Appended to the *beginning* of the current output transforms. + + Returns: + A new group with the appended transforms. + """ + return Group(inputs=(*self.inputs, *inputs), outputs=(*outputs, *self.outputs)) + + +@dataclasses.dataclass(frozen=True) +class CompositeTransform(DataTransformFn): + """A composite transform that applies a sequence of transforms in order.""" + + transforms: Sequence[DataTransformFn] + + def __call__(self, data: DataDict) -> DataDict: + for transform in self.transforms: + data = transform(data) + return data + + +def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn: + """Compose a sequence of transforms into a single transform.""" + return CompositeTransform(transforms) + + +@dataclasses.dataclass(frozen=True) +class RepackTransform(DataTransformFn): + """Repacks an input dictionary into a new dictionary. + + Repacking is defined using a dictionary where the keys are the new keys and the values + are the flattened paths to the old keys. We use '/' as the separator during flattening. + + Example: + { + "images": { + "cam_high": "observation.images.top", + "cam_low": "observation.images.bottom", + }, + "state": "observation.state", + "actions": "action", + } + """ + + structure: at.PyTree[str] + + def __call__(self, data: DataDict) -> DataDict: + flat_item = flatten_dict(data) + return jax.tree.map(lambda k: flat_item[k], self.structure) + + +@dataclasses.dataclass(frozen=True) +class InjectDefaultPrompt(DataTransformFn): + prompt: str | None + + def __call__(self, data: DataDict) -> DataDict: + if self.prompt is not None and "prompt" not in data: + data["prompt"] = np.asarray(self.prompt) + return data + + +@dataclasses.dataclass(frozen=True) +class Normalize(DataTransformFn): + norm_stats: at.PyTree[NormStats] | None + # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. + use_quantiles: bool = False + # If true, will raise an error if any of the keys in the norm stats are not present in the data. + strict: bool = False + + def __post_init__(self): + if self.norm_stats is not None and self.use_quantiles: + _assert_quantile_stats(self.norm_stats) + + def __call__(self, data: DataDict) -> DataDict: + if self.norm_stats is None: + return data + + return apply_tree( + data, + self.norm_stats, + self._normalize_quantile if self.use_quantiles else self._normalize, + strict=self.strict, + ) + + def _normalize(self, x, stats: NormStats): + mean, std = stats.mean[..., : x.shape[-1]], stats.std[..., : x.shape[-1]] + return (x - mean) / (std + 1e-6) + + def _normalize_quantile(self, x, stats: NormStats): + assert stats.q01 is not None + assert stats.q99 is not None + q01, q99 = stats.q01[..., : x.shape[-1]], stats.q99[..., : x.shape[-1]] + return (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0 + + +@dataclasses.dataclass(frozen=True) +class Unnormalize(DataTransformFn): + norm_stats: at.PyTree[NormStats] | None + # If true, will use quantile normalization. 
+    # Otherwise, normal z-score normalization will be used.
+    use_quantiles: bool = False
+
+    def __post_init__(self):
+        if self.norm_stats is not None and self.use_quantiles:
+            _assert_quantile_stats(self.norm_stats)
+
+    def __call__(self, data: DataDict) -> DataDict:
+        if self.norm_stats is None:
+            return data
+
+        # Make sure that all the keys in the norm stats are present in the data.
+        return apply_tree(
+            data,
+            self.norm_stats,
+            self._unnormalize_quantile if self.use_quantiles else self._unnormalize,
+            strict=True,
+        )
+
+    def _unnormalize(self, x, stats: NormStats):
+        mean = pad_to_dim(stats.mean, x.shape[-1], axis=-1, value=0.0)
+        std = pad_to_dim(stats.std, x.shape[-1], axis=-1, value=1.0)
+        return x * (std + 1e-6) + mean
+
+    def _unnormalize_quantile(self, x, stats: NormStats):
+        assert stats.q01 is not None
+        assert stats.q99 is not None
+        q01, q99 = stats.q01, stats.q99
+        if (dim := q01.shape[-1]) < x.shape[-1]:
+            return np.concatenate([(x[..., :dim] + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01, x[..., dim:]], axis=-1)
+        return (x + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
+
+
+@dataclasses.dataclass(frozen=True)
+class ResizeImages(DataTransformFn):
+    height: int
+    width: int
+
+    def __call__(self, data: DataDict) -> DataDict:
+        data["image_padding_mask"] = {}
+        for cam in data["image"]:
+            resized_img, img_padding_mask = image_tools.resize_with_pad(
+                data["image"][cam],
+                self.height,
+                self.width,
+                return_mask=True,
+            )
+            data["image"][cam] = resized_img
+            data["image_padding_mask"][cam] = img_padding_mask
+        return data
+
+
+@dataclasses.dataclass(frozen=True)
+class SubsampleActions(DataTransformFn):
+    stride: int
+
+    def __call__(self, data: DataDict) -> DataDict:
+        data["actions"] = data["actions"][:: self.stride]
+        return data
+
+
+@dataclasses.dataclass(frozen=True)
+class DeltaActions(DataTransformFn):
+    """Repacks absolute actions into delta action space."""
+
+    # Boolean mask for the action dimensions to be repacked into delta action space. Length
+    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
+    # See `make_bool_mask` for more details.
+    mask: Sequence[bool] | None
+
+    def __call__(self, data: DataDict) -> DataDict:
+        if "actions" not in data or self.mask is None:
+            return data
+
+        state, actions = data["state"], data["actions"]
+        mask = np.asarray(self.mask)
+        dims = mask.shape[-1]
+        actions[..., :dims] -= np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)
+        data["actions"] = actions
+
+        return data
+
+
+@dataclasses.dataclass(frozen=True)
+class AbsoluteActions(DataTransformFn):
+    """Repacks delta actions into absolute action space."""
+
+    # Boolean mask for the action dimensions to be repacked into absolute action space. Length
+    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
+    # See `make_bool_mask` for more details.
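+    # Example: with mask=(True, True, False), a delta action d and state s map to
+    # (d0 + s0, d1 + s1, d2); this is the exact inverse of DeltaActions with the same mask.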
+ mask: Sequence[bool] | None + + def __call__(self, data: DataDict) -> DataDict: + if "actions" not in data or self.mask is None: + return data + + state, actions = data["state"], data["actions"] + mask = np.asarray(self.mask) + dims = mask.shape[-1] + actions[..., :dims] += np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2) + data["actions"] = actions + + return data + + +@dataclasses.dataclass(frozen=True) +class TokenizePrompt(DataTransformFn): + tokenizer: _tokenizer.PaligemmaTokenizer + discrete_state_input: bool = False + + def __call__(self, data: DataDict) -> DataDict: + if (prompt := data.pop("prompt", None)) is None: + raise ValueError("Prompt is required") + + if self.discrete_state_input: + if (state := data.get("state", None)) is None: + raise ValueError("State is required.") + else: + state = None + + if not isinstance(prompt, str): + prompt = prompt.item() + + tokens, token_masks = self.tokenizer.tokenize(prompt, state) + return {**data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_masks} + + +@dataclasses.dataclass(frozen=True) +class TokenizeFASTInputs(DataTransformFn): + tokenizer: _tokenizer.FASTTokenizer + + def __call__(self, data: DataDict) -> DataDict: + if (prompt := data.pop("prompt", None)) is None: + raise ValueError("Prompt is required") + + if not isinstance(prompt, str): + prompt = prompt.item() + + state, actions = data["state"], data.get("actions") + tokens, token_mask, ar_mask, loss_mask = self.tokenizer.tokenize(prompt, state, actions) + return { + **data, + "tokenized_prompt": tokens, + "tokenized_prompt_mask": token_mask, + "token_ar_mask": ar_mask, + "token_loss_mask": loss_mask, + } + + +@dataclasses.dataclass(frozen=True) +class ExtractFASTActions(DataTransformFn): + tokenizer: _tokenizer.FASTTokenizer + action_horizon: int + action_dim: int + + def __call__(self, data: DataDict) -> DataDict: + if "actions" not in data: + return data + # Model outputs are saved in "actions", but for FAST models they represent tokens. + tokens = data.pop("actions") + actions = self.tokenizer.extract_actions(tokens.astype(np.int32), self.action_horizon, self.action_dim) + return { + **data, + "actions": actions, + } + + +@dataclasses.dataclass(frozen=True) +class PromptFromLeRobotTask(DataTransformFn): + """Extracts a prompt from the current LeRobot dataset task.""" + + # Contains the LeRobot dataset tasks (dataset.meta.tasks). + tasks: dict[int, str] + + def __call__(self, data: DataDict) -> DataDict: + if "task_index" not in data: + raise ValueError('Cannot extract prompt without "task_index"') + + task_index = int(data["task_index"]) + if (prompt := self.tasks.get(task_index)) is None: + raise ValueError(f"{task_index=} not found in task mapping: {self.tasks}") + + return {**data, "prompt": prompt} + + +@dataclasses.dataclass(frozen=True) +class PadStatesAndActions(DataTransformFn): + """Zero-pads states and actions to the model action dimension.""" + + model_action_dim: int + + def __call__(self, data: DataDict) -> DataDict: + data["state"] = pad_to_dim(data["state"], self.model_action_dim, axis=-1) + if "actions" in data: + data["actions"] = pad_to_dim(data["actions"], self.model_action_dim, axis=-1) + return data + + +def flatten_dict(tree: at.PyTree) -> dict: + """Flatten a nested dictionary. Uses '/' as the separator.""" + return traverse_util.flatten_dict(tree, sep="/") + + +def unflatten_dict(tree: dict) -> at.PyTree: + """Unflatten a flattened dictionary. 
+    Assumes that '/' was used as a separator."""
+    return traverse_util.unflatten_dict(tree, sep="/")
+
+
+def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree:
+    """Transform the structure of a nested dictionary using a set of patterns.
+
+    The transformation is defined using the `patterns` dictionary. The keys are the
+    input keys that should be matched and the values are the new names inside the output
+    dictionary. If the value is None, the input key is removed.
+
+    Both keys and values should represent flattened paths using '/' as the separator.
+    Keys can be regular expressions and values can include backreferences to the
+    matched groups (see `re.sub` for more details). Note that the regular expression
+    must match the entire key.
+
+    The order inside the `patterns` dictionary is important. Only the first pattern that
+    matches the input key will be used.
+
+    See unit tests for more examples.
+
+    Args:
+        patterns: A mapping from old keys to new keys.
+        tree: The nested dictionary to transform.
+
+    Returns:
+        The transformed nested dictionary.
+    """
+    data = flatten_dict(tree)
+
+    # Compile the patterns.
+    compiled = {re.compile(k): v for k, v in patterns.items()}
+
+    output = {}
+    for k in data:
+        for pattern, repl in compiled.items():
+            if pattern.fullmatch(k):
+                new_k = pattern.sub(repl, k, count=1) if repl is not None else None
+                break
+        else:
+            # Use the original key if no match is found.
+            new_k = k
+
+        if new_k is not None:
+            if new_k in output:
+                raise ValueError(f"Key '{new_k}' already exists in output")
+            output[new_k] = data[k]
+
+    # Validate the output structure to make sure that it can be unflattened.
+    names = sorted(output)
+    for i in range(len(names) - 1):
+        name, next_name = names[i : i + 2]
+        if next_name.startswith(name + "/"):
+            raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'")
+
+    return unflatten_dict(output)
+
+
+def apply_tree(
+    tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False
+) -> at.PyTree[T]:
+    tree = flatten_dict(tree)
+    selector = flatten_dict(selector)
+
+    def transform(k: str, v: T) -> T:
+        if k in selector:
+            return fn(v, selector[k])
+        return v
+
+    if strict:
+        for k in selector:
+            if k not in tree:
+                raise ValueError(f"Selector key {k} not found in tree")
+
+    return unflatten_dict({k: transform(k, v) for k, v in tree.items()})
+
+
+def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
+    """Pad an array to the target dimension with `value` (0 by default) along the specified axis."""
+    current_dim = x.shape[axis]
+    if current_dim < target_dim:
+        pad_width = [(0, 0)] * len(x.shape)
+        pad_width[axis] = (0, target_dim - current_dim)
+        return np.pad(x, pad_width, constant_values=value)
+    return x
+
+
+def make_bool_mask(*dims: int) -> tuple[bool, ...]:
+    """Make a boolean mask for the given dimensions.
+
+    Example:
+        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
+        make_bool_mask(2, 0, 2) == (True, True, True, True)
+
+    Args:
+        dims: The dimensions to make the mask for.
+
+    Returns:
+        A tuple of booleans.
+    """
+    result = []
+    for dim in dims:
+        if dim > 0:
+            result.extend([True] * dim)
+        else:
+            result.extend([False] * (-dim))
+    return tuple(result)
+
+
+def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None:
+    for k, v in flatten_dict(norm_stats).items():
+        if v.q01 is None or v.q99 is None:
+            raise ValueError(
+                f"quantile stats must be provided if use_quantile_norm is True. Key {k} is missing q01 or q99."
+            )
diff --git a/capvector-pi05/src/openpi/transforms_test.py b/capvector-pi05/src/openpi/transforms_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ef17015132c940545a7af27b470806412431fbe
--- /dev/null
+++ b/capvector-pi05/src/openpi/transforms_test.py
@@ -0,0 +1,121 @@
+import numpy as np
+import pytest
+
+import openpi.models.tokenizer as _tokenizer
+import openpi.transforms as _transforms
+
+
+def test_repack_transform():
+    transform = _transforms.RepackTransform(
+        structure={
+            "a": {"b": "b/c"},
+            "d": "e/f",
+        }
+    )
+    item = {"b": {"c": 1}, "e": {"f": 2}}
+    assert transform(item) == {"a": {"b": 1}, "d": 2}
+
+
+def test_delta_actions():
+    item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}
+
+    transform = _transforms.DeltaActions(mask=[False, True])
+    transformed = transform(item)
+
+    assert np.all(transformed["state"] == np.array([1, 2, 3]))
+    assert np.all(transformed["actions"] == np.array([[3, 2, 5], [5, 4, 7]]))
+
+
+def test_delta_actions_noop():
+    item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}
+
+    # No-op when the mask is disabled.
+    transform = _transforms.DeltaActions(mask=None)
+    assert transform(item) is item
+
+    # No-op when there are no actions in the input.
+    del item["actions"]
+    transform = _transforms.DeltaActions(mask=[True, False])
+    assert transform(item) is item
+
+
+def test_absolute_actions():
+    item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}
+
+    transform = _transforms.AbsoluteActions(mask=[False, True])
+    transformed = transform(item)
+
+    assert np.all(transformed["state"] == np.array([1, 2, 3]))
+    assert np.all(transformed["actions"] == np.array([[3, 6, 5], [5, 8, 7]]))
+
+
+def test_absolute_actions_noop():
+    item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}
+
+    # No-op when the mask is disabled.
+    transform = _transforms.AbsoluteActions(mask=None)
+    assert transform(item) is item
+
+    # No-op when there are no actions in the input.
+    del item["actions"]
+    transform = _transforms.AbsoluteActions(mask=[True, False])
+    assert transform(item) is item
+
+
+def test_make_bool_mask():
+    assert _transforms.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
+    assert _transforms.make_bool_mask(2, 0, 2) == (True, True, True, True)
+
+
+def test_tokenize_prompt():
+    tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12)
+    transform = _transforms.TokenizePrompt(tokenizer)
+
+    data = transform({"prompt": "Hello, world!"})
+
+    tok_prompt, tok_mask = tokenizer.tokenize("Hello, world!")
+    assert np.allclose(tok_prompt, data["tokenized_prompt"])
+    assert np.allclose(tok_mask, data["tokenized_prompt_mask"])
+
+
+def test_tokenize_no_prompt():
+    transform = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer())
+
+    with pytest.raises(ValueError, match="Prompt is required"):
+        transform({})
+
+
+def test_transform_dict():
+    # Rename and remove keys.
+    input = {"a": {"b": 1, "c": 2}}
+    output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, input)
+    assert output == {"a": {"c": 1}}
+
+    # Raises an error since the renamed key conflicts with an existing key.
+    with pytest.raises(ValueError, match="Key 'a/c' already exists in output"):
+        _transforms.transform_dict({"a/b": "a/c"}, input)
+
+    # Full match is required and so nothing will be removed.
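+    # ("a" must fully match a flattened key such as "a/b", which it does not,
+    # so no pattern applies here.)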
+ input = {"a": {"b": 1, "c": 2}} + output = _transforms.transform_dict({"a": None}, input) + assert output == input + + # The regex matches the entire key and so the entire input will be removed. + input = {"a": {"b": 1, "c": 2}} + output = _transforms.transform_dict({"a.+": None}, input) + assert output == {} + + # Replace keys using backreferences. All leaves named 'c' are replaced with 'd'. + input = {"a": {"b": 1, "c": 1}, "b": {"c": 2}} + output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, input) + assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}} + + +def test_extract_prompt_from_task(): + transform = _transforms.PromptFromLeRobotTask({1: "Hello, world!"}) + + data = transform({"task_index": 1}) + assert data["prompt"] == "Hello, world!" + + with pytest.raises(ValueError, match="task_index=2 not found in task mapping"): + transform({"task_index": 2})