diff --git a/capvector-pi05/src/openpi/policies/policy_test.py b/capvector-pi05/src/openpi/policies/policy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..adae783af773deae088f7920078bb8bb598f9194 --- /dev/null +++ b/capvector-pi05/src/openpi/policies/policy_test.py @@ -0,0 +1,34 @@ +from openpi_client import action_chunk_broker +import pytest + +from openpi.policies import aloha_policy +from openpi.policies import policy_config as _policy_config +from openpi.training import config as _config + + +@pytest.mark.manual +def test_infer(): + config = _config.get_config("pi0_aloha_sim") + policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim") + + example = aloha_policy.make_aloha_example() + result = policy.infer(example) + + assert result["actions"].shape == (config.model.action_horizon, 14) + + +@pytest.mark.manual +def test_broker(): + config = _config.get_config("pi0_aloha_sim") + policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim") + + broker = action_chunk_broker.ActionChunkBroker( + policy, + # Only execute the first half of the chunk. + action_horizon=config.model.action_horizon // 2, + ) + + example = aloha_policy.make_aloha_example() + for _ in range(config.model.action_horizon): + outputs = broker.infer(example) + assert outputs["actions"].shape == (14,) diff --git a/capvector-pi05/src/openpi/serving/websocket_policy_server.py b/capvector-pi05/src/openpi/serving/websocket_policy_server.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6916d18c52f2521400bc382b3052f0bab2ba6f --- /dev/null +++ b/capvector-pi05/src/openpi/serving/websocket_policy_server.py @@ -0,0 +1,90 @@ +import asyncio +import http +import logging +import time +import traceback + +from openpi_client import base_policy as _base_policy +from openpi_client import msgpack_numpy +import websockets.asyncio.server as _server +import websockets.frames + +logger = logging.getLogger(__name__) + + +class WebsocketPolicyServer: + """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation. + + Currently only implements the `load` and `infer` methods. + """ + + def __init__( + self, + policy: _base_policy.BasePolicy, + host: str = "0.0.0.0", + port: int | None = None, + metadata: dict | None = None, + ) -> None: + self._policy = policy + self._host = host + self._port = port + self._metadata = metadata or {} + logging.getLogger("websockets.server").setLevel(logging.INFO) + + def serve_forever(self) -> None: + asyncio.run(self.run()) + + async def run(self): + async with _server.serve( + self._handler, + self._host, + self._port, + compression=None, + max_size=None, + process_request=_health_check, + ) as server: + await server.serve_forever() + + async def _handler(self, websocket: _server.ServerConnection): + logger.info(f"Connection from {websocket.remote_address} opened") + packer = msgpack_numpy.Packer() + + await websocket.send(packer.pack(self._metadata)) + + prev_total_time = None + while True: + try: + start_time = time.monotonic() + obs = msgpack_numpy.unpackb(await websocket.recv()) + + infer_time = time.monotonic() + action = self._policy.infer(obs) + infer_time = time.monotonic() - infer_time + + action["server_timing"] = { + "infer_ms": infer_time * 1000, + } + if prev_total_time is not None: + # We can only record the last total time since we also want to include the send time. 
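+                    # I.e. the `prev_total_ms` attached to this response covers the previous
+                    # iteration's recv + infer + send, which is why it can only be reported one step late.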
+ action["server_timing"]["prev_total_ms"] = prev_total_time * 1000 + + await websocket.send(packer.pack(action)) + prev_total_time = time.monotonic() - start_time + + except websockets.ConnectionClosed: + logger.info(f"Connection from {websocket.remote_address} closed") + break + except Exception: + await websocket.send(traceback.format_exc()) + await websocket.close( + code=websockets.frames.CloseCode.INTERNAL_ERROR, + reason="Internal server error. Traceback included in previous frame.", + ) + raise + + +def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None: + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + # Continue with the normal request handling. + return None diff --git a/capvector-pi05/src/openpi/shared/__init__.py b/capvector-pi05/src/openpi/shared/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-pi05/src/openpi/shared/download.py b/capvector-pi05/src/openpi/shared/download.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd5304cb78d912f1bb4d63ce55b7c96f4c555d6 --- /dev/null +++ b/capvector-pi05/src/openpi/shared/download.py @@ -0,0 +1,194 @@ +import concurrent.futures +import datetime +import logging +import os +import pathlib +import re +import shutil +import stat +import time +import urllib.parse + +import filelock +import fsspec +import fsspec.generic +import tqdm_loggable.auto as tqdm + +# Environment variable to control cache directory path, ~/.cache/openpi will be used by default. +_OPENPI_DATA_HOME = "OPENPI_DATA_HOME" +DEFAULT_CACHE_DIR = "~/.cache/openpi" + +logger = logging.getLogger(__name__) + + +def get_cache_dir() -> pathlib.Path: + cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve() + cache_dir.mkdir(parents=True, exist_ok=True) + _set_folder_permission(cache_dir) + return cache_dir + + +def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path: + """Download a file or directory from a remote filesystem to the local cache, and return the local path. + + If the local file already exists, it will be returned directly. + + It is safe to call this function concurrently from multiple processes. + See `get_cache_dir` for more details on the cache directory. + + Args: + url: URL to the file to download. + force_download: If True, the file will be downloaded even if it already exists in the cache. + **kwargs: Additional arguments to pass to fsspec. + + Returns: + Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute. + """ + # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem. + parsed = urllib.parse.urlparse(url) + + # Short circuit if this is a local path. + if parsed.scheme == "": + path = pathlib.Path(url) + if not path.exists(): + raise FileNotFoundError(f"File not found at {url}") + return path.resolve() + + cache_dir = get_cache_dir() + + local_path = cache_dir / parsed.netloc / parsed.path.strip("/") + local_path = local_path.resolve() + + # Check if the cache should be invalidated. 
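+    # The actual removal happens further down, after the file lock has been acquired, so that
+    # concurrent callers cannot race between this existence check and the deletion.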
+ invalidate_cache = False + if local_path.exists(): + if force_download or _should_invalidate_cache(cache_dir, local_path): + invalidate_cache = True + else: + return local_path + + try: + lock_path = local_path.with_suffix(".lock") + with filelock.FileLock(lock_path): + # Ensure consistent permissions for the lock file. + _ensure_permissions(lock_path) + # First, remove the existing cache if it is expired. + if invalidate_cache: + logger.info(f"Removing expired cached entry: {local_path}") + if local_path.is_dir(): + shutil.rmtree(local_path) + else: + local_path.unlink() + + # Download the data to a local cache. + logger.info(f"Downloading {url} to {local_path}") + scratch_path = local_path.with_suffix(".partial") + _download_fsspec(url, scratch_path, **kwargs) + + shutil.move(scratch_path, local_path) + _ensure_permissions(local_path) + + except PermissionError as e: + msg = ( + f"Local file permission error was encountered while downloading {url}. " + f"Please try again after removing the cached data using: `rm -rf {local_path}*`" + ) + raise PermissionError(msg) from e + + return local_path + + +def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None: + """Download a file from a remote filesystem to the local cache, and return the local path.""" + fs, _ = fsspec.core.url_to_fs(url, **kwargs) + info = fs.info(url) + # Folders are represented by 0-byte objects with a trailing forward slash. + if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))): + total_size = fs.du(url) + else: + total_size = info["size"] + with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar: + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = executor.submit(fs.get, url, local_path, recursive=is_dir) + while not future.done(): + current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file()) + pbar.update(current_size - pbar.n) + time.sleep(1) + pbar.update(total_size - pbar.n) + + +def _set_permission(path: pathlib.Path, target_permission: int): + """chmod requires executable permission to be set, so we skip if the permission is already match with the target.""" + if path.stat().st_mode & target_permission == target_permission: + logger.debug(f"Skipping {path} because it already has correct permissions") + return + path.chmod(target_permission) + logger.debug(f"Set {path} to {target_permission}") + + +def _set_folder_permission(folder_path: pathlib.Path) -> None: + """Set folder permission to be read, write and searchable.""" + _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) + + +def _ensure_permissions(path: pathlib.Path) -> None: + """Since we are sharing cache directory with containerized runtime as well as training script, we need to + ensure that the cache directory has the correct permissions. 
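+
+    Directories between the cache root and `path` are made readable, writable and searchable for
+    everyone; files underneath `path` are made readable/writable (keeping the executable bit where
+    it was already set).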
+ """ + + def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None: + cache_dir = get_cache_dir() + relative_path = path.relative_to(cache_dir) + moving_path = cache_dir + for part in relative_path.parts: + _set_folder_permission(moving_path / part) + moving_path = moving_path / part + + def _set_file_permission(file_path: pathlib.Path) -> None: + """Set all files to be read & writable, if it is a script, keep it as a script.""" + file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH + if file_path.stat().st_mode & 0o100: + _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + else: + _set_permission(file_path, file_rw) + + _setup_folder_permission_between_cache_dir_and_path(path) + for root, dirs, files in os.walk(str(path)): + root_path = pathlib.Path(root) + for file in files: + file_path = root_path / file + _set_file_permission(file_path) + + for dir in dirs: + dir_path = root_path / dir + _set_folder_permission(dir_path) + + +def _get_mtime(year: int, month: int, day: int) -> float: + """Get the mtime of a given date at midnight UTC.""" + date = datetime.datetime(year, month, day, tzinfo=datetime.UTC) + return time.mktime(date.timetuple()) + + +# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format). +# Partial matching will be used from top to bottom and the first match will be chosen. +# Cached entries will be retained only if they are newer than the expiration timestamp. +_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = { + re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17), + re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6), + re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3), +} + + +def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool: + """Invalidate the cache if it is expired. Return True if the cache was invalidated.""" + + assert local_path.exists(), f"File not found at {local_path}" + + relative_path = str(local_path.relative_to(cache_dir)) + for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items(): + if pattern.match(relative_path): + # Remove if not newer than the expiration timestamp. 
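+            # For example, a cached "openpi-assets/checkpoints/pi0_libero" entry downloaded before
+            # 2025-02-06 is treated as stale and will be re-fetched by `maybe_download`.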
+ return local_path.stat().st_mtime <= expire_time + + return False diff --git a/capvector-pi05/src/openpi/shared/download_test.py b/capvector-pi05/src/openpi/shared/download_test.py new file mode 100644 index 0000000000000000000000000000000000000000..48417fe3ad386a642e03bdb4a16b884ec951e1eb --- /dev/null +++ b/capvector-pi05/src/openpi/shared/download_test.py @@ -0,0 +1,54 @@ +import pathlib + +import pytest + +import openpi.shared.download as download + + +@pytest.fixture(scope="session", autouse=True) +def set_openpi_data_home(tmp_path_factory): + temp_dir = tmp_path_factory.mktemp("openpi_data") + with pytest.MonkeyPatch().context() as mp: + mp.setenv("OPENPI_DATA_HOME", str(temp_dir)) + yield + + +def test_download_local(tmp_path: pathlib.Path): + local_path = tmp_path / "local" + local_path.touch() + + result = download.maybe_download(str(local_path)) + assert result == local_path + + with pytest.raises(FileNotFoundError): + download.maybe_download("bogus") + + +def test_download_gs_dir(): + remote_path = "gs://openpi-assets/testdata/random" + + local_path = download.maybe_download(remote_path) + assert local_path.exists() + + new_local_path = download.maybe_download(remote_path) + assert new_local_path == local_path + + +def test_download_gs(): + remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin" + + local_path = download.maybe_download(remote_path) + assert local_path.exists() + + new_local_path = download.maybe_download(remote_path) + assert new_local_path == local_path + + +def test_download_fsspec(): + remote_path = "gs://big_vision/paligemma_tokenizer.model" + + local_path = download.maybe_download(remote_path, gs={"token": "anon"}) + assert local_path.exists() + + new_local_path = download.maybe_download(remote_path, gs={"token": "anon"}) + assert new_local_path == local_path diff --git a/capvector-pi05/src/openpi/shared/image_tools.py b/capvector-pi05/src/openpi/shared/image_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..50548c1b1176616dc45f257f4d97cf744ea7fa98 --- /dev/null +++ b/capvector-pi05/src/openpi/shared/image_tools.py @@ -0,0 +1,186 @@ +import functools + +import jax +import jax.numpy as jnp +import torch +import torch.nn.functional as F # noqa: N812 + +import openpi.shared.array_typing as at + + +@functools.partial(jax.jit, static_argnums=(1, 2, 3)) +@at.typecheck +def resize_with_pad( + images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"], + height: int, + width: int, + method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR, +) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]: + """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. 
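+
+    Example (shape behaviour only):
+        images = jnp.zeros((480, 640, 3), dtype=jnp.uint8)
+        out = resize_with_pad(images, 224, 224)  # -> (224, 224, 3), letterboxed with black rows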
+ """ + has_batch_dim = images.ndim == 4 + if not has_batch_dim: + images = images[None] # type: ignore + cur_height, cur_width = images.shape[1:3] + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_images = jax.image.resize( + images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method + ) + if images.dtype == jnp.uint8: + # round from float back to uint8 + resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8) + elif images.dtype == jnp.float32: + resized_images = resized_images.clip(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + padded_images = jnp.pad( + resized_images, + ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)), + constant_values=0 if images.dtype == jnp.uint8 else -1.0, + ) + + if not has_batch_dim: + padded_images = padded_images[0] + return padded_images + + +def resize_with_pad_torch( + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + # Convert to channels-first for torch operations + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + if batch_size == 1 and images.shape[0] == 1: + padded_images = 
padded_images.squeeze(0) # Remove batch dimension if it was added + + return padded_images + + +def replace_padding_0to1_torch(image: torch.Tensor,) -> torch.Tensor: + """PyTorch version of replace_padding_0to1. + OpenPI requires images with 0 value paddings, while VGGT series requires 1 value paddings. + Here it achieves this bounding-box based padding replacement. + Args: + image: Tensor of shape [*b, h, w, c] + Returns: + Padding-replaced tensor with same shape as input + """ + single = False + if image.dim() == 3: + image = image.unsqueeze(0) + single = True + + b, h, w, c = image.shape + device = image.device + + nonzero_any = (image != 0).any(dim=-1) + + row_any = nonzero_any.any(dim=2) + col_any = nonzero_any.any(dim=1) + + top = row_any.to(torch.float32).argmax(dim=1) + bottom = h - 1 - row_any.flip(dims=[1]).to(torch.float32).argmax(dim=1) + left = col_any.to(torch.float32).argmax(dim=1) + right = w - 1 - col_any.flip(dims=[1]).to(torch.float32).argmax(dim=1) + + has_any = row_any.any(dim=1) + top = torch.where(has_any, top, torch.zeros_like(top)) + bottom = torch.where(has_any, bottom, torch.full_like(bottom, h - 1)) + left = torch.where(has_any, left, torch.zeros_like(left)) + right = torch.where(has_any, right, torch.full_like(right, w - 1)) + + rows = torch.arange(h, device=device).view(1, h, 1) + cols = torch.arange(w, device=device).view(1, 1, w) + top_v = top.view(b, 1, 1) + bottom_v = bottom.view(b, 1, 1) + left_v = left.view(b, 1, 1) + right_v = right.view(b, 1, 1) + + row_mask = (rows >= top_v) & (rows <= bottom_v) + col_mask = (cols >= left_v) & (cols <= right_v) + inside_mask = row_mask & col_mask + + padding_mask = ~inside_mask + + pixel_zero = (image == 0).all(dim=-1) + + final_mask = padding_mask & pixel_zero + + if final_mask.any(): + mask_exp = final_mask.unsqueeze(-1).expand_as(image) + one_t = torch.tensor(1, dtype=image.dtype, device=device) + image = torch.where(mask_exp, one_t, image) + + if single: + image = image.squeeze(0) + return image \ No newline at end of file diff --git a/capvector-pi05/src/openpi/shared/image_tools_test.py b/capvector-pi05/src/openpi/shared/image_tools_test.py new file mode 100644 index 0000000000000000000000000000000000000000..be1b7e1bcd08106d61b04bd40e73e5eeecb9f484 --- /dev/null +++ b/capvector-pi05/src/openpi/shared/image_tools_test.py @@ -0,0 +1,37 @@ +import jax.numpy as jnp + +from openpi.shared import image_tools + + +def test_resize_with_pad_shapes(): + # Test case 1: Resize image with larger dimensions + images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8) # Input images of shape (batch_size, height, width, channels) + height = 20 + width = 20 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (2, height, width, 3) + assert jnp.all(resized_images == 0) + + # Test case 2: Resize image with smaller dimensions + images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8) + height = 15 + width = 15 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (3, height, width, 3) + assert jnp.all(resized_images == 0) + + # Test case 3: Resize image with the same dimensions + images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8) + height = 50 + width = 50 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (1, height, width, 3) + assert jnp.all(resized_images == 0) + + # Test case 3: Resize image with odd-numbered padding + images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8) + height = 60 + 
width = 80 + resized_images = image_tools.resize_with_pad(images, height, width) + assert resized_images.shape == (1, height, width, 3) + assert jnp.all(resized_images == 0) diff --git a/capvector-pi05/src/openpi/shared/nnx_utils.py b/capvector-pi05/src/openpi/shared/nnx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08907a48a6d96bdaf5ce42f1da836c74bf0285ed --- /dev/null +++ b/capvector-pi05/src/openpi/shared/nnx_utils.py @@ -0,0 +1,69 @@ +from collections.abc import Callable +import dataclasses +import functools +import inspect +import re +from typing import Any, ParamSpec, TypeVar + +import flax.nnx as nnx +import jax + +P = ParamSpec("P") +R = TypeVar("R") + + +def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]: + """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process. + + Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much + more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module + mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must + traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details. + + `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by + `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was + when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded + after the method call completes. + """ + if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)): + raise ValueError("module_jit must only be used on bound methods of nnx.Modules.") + + graphdef, state = nnx.split(meth.__self__) + + def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R: + module = nnx.merge(graphdef, state) + return meth.__func__(module, *args, **kwargs) + + jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs) + + @functools.wraps(meth) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return jitted_fn(state, *args, **kwargs) + + return wrapper + + +@dataclasses.dataclass(frozen=True) +class PathRegex: + """NNX Filter that matches paths using a regex. + + By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument. 
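+
+    Example:
+        # Matches any parameter whose joined path ends in "kernel", e.g. "layers/0/attn/kernel".
+        kernel_filter = PathRegex(r".*/kernel")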
+ """ + + pattern: str | re.Pattern + sep: str = "/" + + def __post_init__(self): + if not isinstance(self.pattern, re.Pattern): + object.__setattr__(self, "pattern", re.compile(self.pattern)) + + def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool: + joined_path = self.sep.join(str(x) for x in path) + assert isinstance(self.pattern, re.Pattern) + return self.pattern.fullmatch(joined_path) is not None + + +def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State: + """Apply a function to the leaves of the state that match the filter.""" + filtered_keys = set(state.filter(filter).flat_state()) + return state.map(lambda k, v: fn(v) if k in filtered_keys else v) diff --git a/capvector-pi05/src/openpi/shared/normalize.py b/capvector-pi05/src/openpi/shared/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bf6100fb506a21c064cae725eb04fea0e00017 --- /dev/null +++ b/capvector-pi05/src/openpi/shared/normalize.py @@ -0,0 +1,146 @@ +import json +import pathlib + +import numpy as np +import numpydantic +import pydantic + + +@pydantic.dataclasses.dataclass +class NormStats: + mean: numpydantic.NDArray + std: numpydantic.NDArray + q01: numpydantic.NDArray | None = None # 1st quantile + q99: numpydantic.NDArray | None = None # 99th quantile + + +class RunningStats: + """Compute running statistics of a batch of vectors.""" + + def __init__(self): + self._count = 0 + self._mean = None + self._mean_of_squares = None + self._min = None + self._max = None + self._histograms = None + self._bin_edges = None + self._num_quantile_bins = 5000 # for computing quantiles on the fly + + def update(self, batch: np.ndarray) -> None: + """ + Update the running statistics with a batch of vectors. + + Args: + vectors (np.ndarray): An array where all dimensions except the last are batch dimensions. + """ + batch = batch.reshape(-1, batch.shape[-1]) + num_elements, vector_length = batch.shape + if self._count == 0: + self._mean = np.mean(batch, axis=0) + self._mean_of_squares = np.mean(batch**2, axis=0) + self._min = np.min(batch, axis=0) + self._max = np.max(batch, axis=0) + self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] + self._bin_edges = [ + np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) + for i in range(vector_length) + ] + else: + if vector_length != self._mean.size: + raise ValueError("The length of new vectors does not match the initialized vector length.") + new_max = np.max(batch, axis=0) + new_min = np.min(batch, axis=0) + max_changed = np.any(new_max > self._max) + min_changed = np.any(new_min < self._min) + self._max = np.maximum(self._max, new_max) + self._min = np.minimum(self._min, new_min) + + if max_changed or min_changed: + self._adjust_histograms() + + self._count += num_elements + + batch_mean = np.mean(batch, axis=0) + batch_mean_of_squares = np.mean(batch**2, axis=0) + + # Update running mean and mean of squares. + self._mean += (batch_mean - self._mean) * (num_elements / self._count) + self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count) + + self._update_histograms(batch) + + def get_statistics(self) -> NormStats: + """ + Compute and return the statistics of the vectors processed so far. + + Returns: + dict: A dictionary containing the computed statistics. 
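+            Concretely, a NormStats dataclass holding the running mean, standard deviation, and the
+            histogram-estimated 1st/99th quantiles (q01/q99).
+
+        Example:
+            stats = RunningStats()
+            stats.update(np.random.rand(64, 7))  # 64 vectors of length 7
+            norm = stats.get_statistics()        # norm.mean, norm.std, norm.q01, norm.q99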
+ """ + if self._count < 2: + raise ValueError("Cannot compute statistics for less than 2 vectors.") + + variance = self._mean_of_squares - self._mean**2 + stddev = np.sqrt(np.maximum(0, variance)) + q01, q99 = self._compute_quantiles([0.01, 0.99]) + return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99) + + def _adjust_histograms(self): + """Adjust histograms when min or max changes.""" + for i in range(len(self._histograms)): + old_edges = self._bin_edges[i] + new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1) + + # Redistribute the existing histogram counts to the new bins + new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i]) + + self._histograms[i] = new_hist + self._bin_edges[i] = new_edges + + def _update_histograms(self, batch: np.ndarray) -> None: + """Update histograms with new vectors.""" + for i in range(batch.shape[1]): + hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) + self._histograms[i] += hist + + def _compute_quantiles(self, quantiles): + """Compute quantiles based on histograms.""" + results = [] + for q in quantiles: + target_count = q * self._count + q_values = [] + for hist, edges in zip(self._histograms, self._bin_edges, strict=True): + cumsum = np.cumsum(hist) + idx = np.searchsorted(cumsum, target_count) + q_values.append(edges[idx]) + results.append(np.array(q_values)) + return results + + +class _NormStatsDict(pydantic.BaseModel): + norm_stats: dict[str, NormStats] + + +def serialize_json(norm_stats: dict[str, NormStats]) -> str: + """Serialize the running statistics to a JSON string.""" + return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2) + + +def deserialize_json(data: str) -> dict[str, NormStats]: + """Deserialize the running statistics from a JSON string.""" + return _NormStatsDict(**json.loads(data)).norm_stats + + +def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None: + """Save the normalization stats to a directory.""" + path = pathlib.Path(directory) / "norm_stats.json" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(serialize_json(norm_stats)) + + +def load(directory: pathlib.Path | str) -> dict[str, NormStats]: + """Load the normalization stats from a directory.""" + path = pathlib.Path(directory) / "norm_stats.json" + if not path.exists(): + raise FileNotFoundError(f"Norm stats file not found at: {path}") + return deserialize_json(path.read_text()) diff --git a/capvector-pi05/src/openpi/shared/normalize_test.py b/capvector-pi05/src/openpi/shared/normalize_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0aa15f9f3f0fd3b8aac6fea2c864d32be00c9b --- /dev/null +++ b/capvector-pi05/src/openpi/shared/normalize_test.py @@ -0,0 +1,43 @@ +import numpy as np + +import openpi.shared.normalize as normalize + + +def test_normalize_update(): + arr = np.arange(12).reshape(4, 3) # 4 vectors of length 3 + + stats = normalize.RunningStats() + for i in range(len(arr)): + stats.update(arr[i : i + 1]) # Update with one vector at a time + results = stats.get_statistics() + + assert np.allclose(results.mean, np.mean(arr, axis=0)) + assert np.allclose(results.std, np.std(arr, axis=0)) + + +def test_serialize_deserialize(): + stats = normalize.RunningStats() + stats.update(np.arange(12).reshape(4, 3)) # 4 vectors of length 3 + + norm_stats = {"test": stats.get_statistics()} + norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats)) + assert 
np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean) + assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std) + + +def test_multiple_batch_dimensions(): + # Test with multiple batch dimensions: (2, 3, 4) where 4 is vector dimension + batch_shape = (2, 3, 4) + arr = np.random.rand(*batch_shape) + + stats = normalize.RunningStats() + stats.update(arr) # Should handle (2, 3, 4) -> reshape to (6, 4) + results = stats.get_statistics() + + # Flatten batch dimensions and compute expected stats + flattened = arr.reshape(-1, arr.shape[-1]) # (6, 4) + expected_mean = np.mean(flattened, axis=0) + expected_std = np.std(flattened, axis=0) + + assert np.allclose(results.mean, expected_mean) + assert np.allclose(results.std, expected_std) diff --git a/capvector-pi05/src/openpi/training/checkpoints.py b/capvector-pi05/src/openpi/training/checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..0f53c71d7e35b72e7d07c957fdb8010f09e23396 --- /dev/null +++ b/capvector-pi05/src/openpi/training/checkpoints.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures as futures +import dataclasses +import logging +from typing import Protocol + +from etils import epath +import jax +import orbax.checkpoint as ocp +import orbax.checkpoint.future as future + +from openpi.shared import array_typing as at +import openpi.shared.normalize as _normalize +import openpi.training.data_loader as _data_loader +import openpi.training.utils as training_utils + + +def initialize_checkpoint_dir( + checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool +) -> tuple[ocp.CheckpointManager, bool]: + checkpoint_dir = epath.Path(checkpoint_dir).resolve() + resuming = False + if checkpoint_dir.exists(): + if overwrite: + checkpoint_dir.rmtree() + checkpoint_dir.mkdir(parents=True, exist_ok=True) + logging.info(f"Wiped checkpoint directory {checkpoint_dir}") + elif resume: + resuming = True + else: + raise FileExistsError( + f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume " + "to indicate how to handle it." + ) + + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + mngr = ocp.CheckpointManager( + checkpoint_dir, + item_handlers={ + "assets": CallbackHandler(), + "train_state": ocp.PyTreeCheckpointHandler(), + "params": ocp.PyTreeCheckpointHandler(), + }, + options=ocp.CheckpointManagerOptions( + max_to_keep=1, + keep_period=keep_period, + create=False, + async_options=ocp.AsyncOptions(timeout_secs=7200), + ), + ) + + # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did + # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a + # checkpoint, since it will fail. + if resuming and tuple(mngr.all_steps()) in [(), (0,)]: + logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.") + resuming = False + + return mngr, resuming + + +def save_state( + checkpoint_manager: ocp.CheckpointManager, + state: training_utils.TrainState, + data_loader: _data_loader.DataLoader, + step: int, +): + def save_assets(directory: epath.Path): + # Save the normalization stats. 
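+        # The stats are written inside the checkpoint's `assets` item under
+        # `<asset_id>/norm_stats.json`, the layout that `load_norm_stats` expects at inference time.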
+ data_config = data_loader.data_config() + norm_stats = data_config.norm_stats + if norm_stats is not None and data_config.asset_id is not None: + _normalize.save(directory / data_config.asset_id, norm_stats) + + # Split params that can be used for inference into a separate item. + with at.disable_typechecking(): + train_state, params = _split_params(state) + items = { + "assets": save_assets, + "train_state": train_state, + "params": {"params": params}, + } + checkpoint_manager.save(step, items) + + +def restore_state( + checkpoint_manager: ocp.CheckpointManager, + state: training_utils.TrainState, + data_loader: _data_loader.DataLoader, + step: int | None = None, +) -> training_utils.TrainState: + del data_loader + + with at.disable_typechecking(): + # Split params that can be used for inference into a separate item. + train_state, params = _split_params(state) + restored = checkpoint_manager.restore( + step, + items={ + "train_state": train_state, + "params": {"params": params}, + }, + ) + return _merge_params(restored["train_state"], restored["params"]) + + +def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None: + norm_stats_dir = epath.Path(assets_dir) / asset_id + norm_stats = _normalize.load(norm_stats_dir) + logging.info(f"Loaded norm stats from {norm_stats_dir}") + return norm_stats + + +class Callback(Protocol): + def __call__(self, directory: epath.Path) -> None: ... + + +class CallbackHandler(ocp.AsyncCheckpointHandler): + """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring.""" + + def save(self, directory: epath.Path, args: CallbackSave): + if jax.process_index() == 0: + args.callback(directory) + + async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]: + return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))] + + def restore(self, *args, **kwargs): + raise NotImplementedError("CallbackHandler does not support restore") + + +@ocp.args.register_with_handler(CallbackHandler, for_save=True) +@dataclasses.dataclass +class CallbackSave(ocp.args.CheckpointArgs): + callback: Callback + + +@ocp.args.register_with_handler(CallbackHandler, for_restore=True) +class CallbackRestore(ocp.args.CheckpointArgs): ... + + +def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]: + if state.ema_params is not None: + params = state.ema_params + train_state = dataclasses.replace(state, ema_params=None) + else: + params = state.params + train_state = dataclasses.replace(state, params={}) + return train_state, params + + +def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState: + # Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split. 
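+    # A non-empty `train_state.params` therefore means `_split_params` moved the EMA weights out and
+    # left the regular params in place; otherwise the regular params themselves were split out.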
+ if train_state.params: + return dataclasses.replace(train_state, ema_params=params["params"]) + return dataclasses.replace(train_state, params=params["params"]) diff --git a/capvector-pi05/src/openpi/training/config.py b/capvector-pi05/src/openpi/training/config.py new file mode 100644 index 0000000000000000000000000000000000000000..51690fa0a2385514eddde0a5716fcda15e5dcc2b --- /dev/null +++ b/capvector-pi05/src/openpi/training/config.py @@ -0,0 +1,1033 @@ +"""See _CONFIGS for the list of available configs.""" + +import abc +from collections.abc import Sequence +import dataclasses +import difflib +import logging +import pathlib +from typing import Any, Literal, Protocol, TypeAlias + +import etils.epath as epath +import flax.nnx as nnx +from typing_extensions import override +import tyro + +import openpi.models.model as _model +import openpi.models.pi0_config as pi0_config +import openpi.models.pi0_fast as pi0_fast +import openpi.models.tokenizer as _tokenizer +import openpi.policies.aloha_policy as aloha_policy +import openpi.policies.droid_policy as droid_policy +import openpi.policies.libero_policy as libero_policy +import openpi.shared.download as _download +import openpi.shared.normalize as _normalize +import openpi.training.droid_rlds_dataset as droid_rlds_dataset +import openpi.training.misc.roboarena_config as roboarena_config +import openpi.training.optimizer as _optimizer +import openpi.training.weight_loaders as weight_loaders +import openpi.transforms as _transforms + +ModelType: TypeAlias = _model.ModelType +# Work around a tyro issue with using nnx.filterlib.Filter directly. +Filter: TypeAlias = nnx.filterlib.Filter + + +@dataclasses.dataclass(frozen=True) +class AssetsConfig: + """Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline. + + These assets will be replicated inside the checkpoint under the `assets/asset_id` directory. + + This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other + centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint + during fine-tuning, use: + + ``` + AssetsConfig( + assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", + asset_id="trossen", + ) + ``` + """ + + # Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from + # a different checkpoint (e.g., base model checkpoint) or some other centralized location. + assets_dir: str | None = None + + # Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe + # different robot platforms. + asset_id: str | None = None + + +@dataclasses.dataclass(frozen=True) +class DataConfig: + # LeRobot repo id. If None, fake data will be created. + repo_id: str | None = None + # Directory within the assets directory containing the data assets. + asset_id: str | None = None + # Contains precomputed normalization stats. If None, normalization will not be performed. + norm_stats: dict[str, _transforms.NormStats] | None = None + + # Used to adopt the inputs from a dataset specific format to a common format + # which is expected by the data transforms. + repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) + # Data transforms, typically include robot specific transformations. Will be applied + # before the data is normalized. See `model.Observation` and `model.Actions` to learn about the + # normalized data. 
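+    # The groups are applied in order: repack_transforms -> data_transforms -> normalization ->
+    # model_transforms.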
+ data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) + # Model specific transforms. Will be applied after the data is normalized. + model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) + # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. + use_quantile_norm: bool = False + + # Names of keys that will be used by the data loader to generate the action sequence. The length of the + # sequence is defined by the `action_horizon` field in the model config. This should be adjusted if your + # LeRobot dataset is using different keys to represent the action. + action_sequence_keys: Sequence[str] = ("actions",) + + # If true, will use the LeRobot dataset task to define the prompt. + prompt_from_task: bool = False + + # Only used for RLDS data loader (ie currently only used for DROID). + rlds_data_dir: str | None = None + # Action space for DROID dataset. + action_space: droid_rlds_dataset.DroidActionSpace | None = None + # Path to the data filter file for DROID dataset + filter_dict_path: str | None = None + + +class GroupFactory(Protocol): + def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group: + """Create a group.""" + + +@dataclasses.dataclass(frozen=True) +class ModelTransformFactory(GroupFactory): + """Creates model transforms for standard pi0 models.""" + + # If provided, will determine the default prompt that be used by the model. + default_prompt: str | None = None + + def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group: + match model_config.model_type: + case _model.ModelType.PI0: + return _transforms.Group( + inputs=[ + _transforms.InjectDefaultPrompt(self.default_prompt), + _transforms.ResizeImages(224, 224), + _transforms.TokenizePrompt( + _tokenizer.PaligemmaTokenizer(model_config.max_token_len), + ), + _transforms.PadStatesAndActions(model_config.action_dim), + ], + ) + case _model.ModelType.PI05: + assert isinstance(model_config, pi0_config.Pi0Config) + return _transforms.Group( + inputs=[ + _transforms.InjectDefaultPrompt(self.default_prompt), + _transforms.ResizeImages(224, 224), + _transforms.TokenizePrompt( + _tokenizer.PaligemmaTokenizer(model_config.max_token_len), + discrete_state_input=model_config.discrete_state_input, + ), + _transforms.PadStatesAndActions(model_config.action_dim), + ], + ) + case _model.ModelType.PI0_FAST: + tokenizer_cls = ( + _tokenizer.FASTTokenizer + if model_config.fast_model_tokenizer is None + else model_config.fast_model_tokenizer + ) + tokenizer_kwargs = ( + {} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs + ) + return _transforms.Group( + inputs=[ + _transforms.InjectDefaultPrompt(self.default_prompt), + _transforms.ResizeImages(224, 224), + _transforms.TokenizeFASTInputs( + tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs), + ), + ], + outputs=[ + _transforms.ExtractFASTActions( + tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs), + action_horizon=model_config.action_horizon, + action_dim=model_config.action_dim, + ) + ], + ) + + +@dataclasses.dataclass(frozen=True) +class DataConfigFactory(abc.ABC): + # The LeRobot repo id. + repo_id: str = tyro.MISSING + # Determines how the assets will be loaded. + assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig) + # Base config that will be updated by the factory. 
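+    # Any field not overridden by `create_base_config` (e.g. prompt_from_task) is taken from this
+    # config as-is.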
+ base_config: tyro.conf.Suppress[DataConfig | None] = None + + @abc.abstractmethod + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + """Create a data config.""" + + def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None + asset_id = self.assets.asset_id or repo_id + return dataclasses.replace( + self.base_config or DataConfig(), + repo_id=repo_id, + asset_id=asset_id, + norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id), + use_quantile_norm=model_config.model_type != ModelType.PI0, + ) + + def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None: + if asset_id is None: + return None + try: + data_assets_dir = str(assets_dir / asset_id) + norm_stats = _normalize.load(_download.maybe_download(data_assets_dir)) + logging.info(f"Loaded norm stats from {data_assets_dir}") + return norm_stats + except FileNotFoundError: + logging.info(f"Norm stats not found in {data_assets_dir}, skipping.") + return None + + +@dataclasses.dataclass(frozen=True) +class FakeDataConfig(DataConfigFactory): + repo_id: str = "fake" + + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + return DataConfig(repo_id=self.repo_id) + + +@dataclasses.dataclass(frozen=True) +class SimpleDataConfig(DataConfigFactory): + # Factory for the data transforms. + data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory) + # Factory for the model transforms. + model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory) + + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + return dataclasses.replace( + self.create_base_config(assets_dirs, model_config), + data_transforms=self.data_transforms(model_config), + model_transforms=self.model_transforms(model_config), + ) + + +@dataclasses.dataclass(frozen=True) +class LeRobotAlohaDataConfig(DataConfigFactory): + # If true, will convert joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. + use_delta_joint_actions: bool = True + # If provided, will be injected into the input data if the "prompt" key is not present. + default_prompt: str | None = None + # If true, this will convert the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. People who + # use standard Aloha data should set this to true. + adapt_to_pi: bool = True + + # Repack transforms. + repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field( + default=_transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "images": {"cam_high": "observation.images.top"}, + "state": "observation.state", + "actions": "action", + } + ) + ] + ) + ) + # Action keys that will be used to read the action sequence from the dataset. 
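+    # Aloha LeRobot datasets store actions under "action" (singular), unlike the DataConfig default
+    # of ("actions",), hence the override here.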
+ action_sequence_keys: Sequence[str] = ("action",) + + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + data_transforms = _transforms.Group( + inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)], + outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)], + ) + if self.use_delta_joint_actions: + delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1) + data_transforms = data_transforms.push( + inputs=[_transforms.DeltaActions(delta_action_mask)], + outputs=[_transforms.AbsoluteActions(delta_action_mask)], + ) + + model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config) + + return dataclasses.replace( + self.create_base_config(assets_dirs, model_config), + repack_transforms=self.repack_transforms, + data_transforms=data_transforms, + model_transforms=model_transforms, + action_sequence_keys=self.action_sequence_keys, + ) + + +@dataclasses.dataclass(frozen=True) +class LeRobotLiberoDataConfig(DataConfigFactory): + """ + This config is used to configure transforms that are applied at various parts of the data pipeline. + For your own dataset, you can copy this class and modify the transforms to match your dataset based on the + comments below. + """ + + extra_delta_transform: bool = False + + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + # The repack transform is *only* applied to the data coming from the dataset, + # and *not* during inference. We can use it to make inputs from the dataset look + # as close as possible to those coming from the inference environment (e.g. match the keys). + # Below, we match the keys in the dataset (which we defined in the data conversion script) to + # the keys we use in our inference pipeline (defined in the inference script for libero). + # For your own dataset, first figure out what keys your environment passes to the policy server + # and then modify the mappings below so your dataset's keys get matched to those target keys. + # The repack transform simply remaps key names here. + repack_transform = _transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "observation/image": "image", + "observation/wrist_image": "wrist_image", + "observation/state": "state", + "actions": "actions", + "prompt": "prompt", + } + ) + ] + ) + + # The data transforms are applied to the data coming from the dataset *and* during inference. + # Below, we define the transforms for data going into the model (``inputs``) and the transforms + # for data coming out of the model (``outputs``) (the latter is only used during inference). + # We defined these transforms in `libero_policy.py`. You can check the detailed comments there for + # how to modify the transforms to match your dataset. Once you created your own transforms, you can + # replace the transforms below with your own. + data_transforms = _transforms.Group( + inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)], + outputs=[libero_policy.LiberoOutputs()], + ) + + # One additional data transform: pi0 models are trained on delta actions (relative to the first + # state in each action chunk). IF your data has ``absolute`` actions (e.g. target joint angles) + # you can uncomment the following line to convert the actions to delta actions. The only exception + # is for the gripper actions which are always absolute. 
+ # In the example below, we would apply the delta conversion to the first 6 actions (joints) and + # leave the 7th action (gripper) unchanged, i.e. absolute. + # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to + # apply a separate delta conversion (that's why it's commented out). Choose whether to apply this + # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box. + + # LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this + # extra delta transform. + if self.extra_delta_transform: + delta_action_mask = _transforms.make_bool_mask(6, -1) + data_transforms = data_transforms.push( + inputs=[_transforms.DeltaActions(delta_action_mask)], + outputs=[_transforms.AbsoluteActions(delta_action_mask)], + ) + + # Model transforms include things like tokenizing the prompt and action targets + # You do not need to change anything here for your own dataset. + model_transforms = ModelTransformFactory()(model_config) + + # We return all data transforms for training and inference. No need to change anything here. + return dataclasses.replace( + self.create_base_config(assets_dirs, model_config), + repack_transforms=repack_transform, + data_transforms=data_transforms, + model_transforms=model_transforms, + ) + + +@dataclasses.dataclass(frozen=True) +class RLDSDroidDataConfig(DataConfigFactory): + """ + Config for training on DROID, using RLDS data format (for efficient training on larger datasets). + """ + + rlds_data_dir: str | None = None + action_space: droid_rlds_dataset.DroidActionSpace | None = None + + # Filtering options. Can pass a path to a dictionary that maps episodes to timestep ranges + # to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with + # f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata. + # Path to the filter dictionary file. + filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json" + + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + repack_transform = _transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "observation/exterior_image_1_left": "observation/image", + "observation/wrist_image_left": "observation/wrist_image", + "observation/joint_position": "observation/joint_position", + "observation/gripper_position": "observation/gripper_position", + "actions": "actions", + "prompt": "prompt", + } + ) + ] + ) + + data_transforms = _transforms.Group( + inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)], + outputs=[droid_policy.DroidOutputs()], + ) + + if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION: + # Data loader returns absolute joint position actions -- convert to delta actions for training. + delta_action_mask = _transforms.make_bool_mask(7, -1) + data_transforms = data_transforms.push( + inputs=[_transforms.DeltaActions(delta_action_mask)], + outputs=[_transforms.AbsoluteActions(delta_action_mask)], + ) + + model_transforms = ModelTransformFactory()(model_config) + + assert self.rlds_data_dir is not None, "Need to set rlds data dir for RLDS data loader." 
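+        # The RLDS-specific fields below (rlds_data_dir, action_space, filter_dict_path) are passed
+        # through to the DROID RLDS data loader rather than being consumed by the transforms.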
+ + return dataclasses.replace( + self.create_base_config(assets_dirs, model_config), + repack_transforms=repack_transform, + data_transforms=data_transforms, + model_transforms=model_transforms, + rlds_data_dir=self.rlds_data_dir, + action_space=self.action_space, + filter_dict_path=self.filter_dict_path, + ) + + +@dataclasses.dataclass(frozen=True) +class LeRobotDROIDDataConfig(DataConfigFactory): + """ + Example data config for custom DROID dataset in LeRobot format. + To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py + """ + + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + repack_transform = _transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "observation/exterior_image_1_left": "exterior_image_1_left", + "observation/exterior_image_2_left": "exterior_image_2_left", + "observation/wrist_image_left": "wrist_image_left", + "observation/joint_position": "joint_position", + "observation/gripper_position": "gripper_position", + "actions": "actions", + "prompt": "prompt", + } + ) + ] + ) + # We assume joint *velocity* actions, so we should *not* apply an additional delta transform. + data_transforms = _transforms.Group( + inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)], + outputs=[droid_policy.DroidOutputs()], + ) + model_transforms = ModelTransformFactory()(model_config) + + return dataclasses.replace( + self.create_base_config(assets_dirs, model_config), + repack_transforms=repack_transform, + data_transforms=data_transforms, + model_transforms=model_transforms, + ) + + +@dataclasses.dataclass(frozen=True) +class TrainConfig: + # Name of the config. Must be unique. Will be used to reference this config. + name: tyro.conf.Suppress[str] + # Project name. + project_name: str = "openpi" + # Experiment name. Will be used to name the metadata and checkpoint directories. + exp_name: str = tyro.MISSING + + # Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models + # -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may + # define additional attributes. + model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config) + + # A weight loader can optionally load (possibly partial) weights from disk after the model is initialized. + weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader) + + # Optional path to a PyTorch checkpoint to load weights from. + pytorch_weight_path: str | None = None + + # Spatial Forcing configs + vggt_weight_path: str | None = None + vggt_dim: int = 1024 + + vla_layers_align: int | None = None # total 18 for paligemma-2b + vggt_layers_align: int | None = None # total 24 for VGGT + + pooling_func: str | None = None + use_vggt_pe: bool | None = None + use_vlm_norm: bool | None = None + + align_loss_coeff: float = 0.0 + + # CapVector configs + regularization_vector_path: str | None = None + regularization_coeff: float = 0.0 + + # Precision for PyTorch training. 
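+    # bfloat16 roughly halves memory use compared to float32; switch to float32 if you run into
+    # numerical issues during fine-tuning.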
+ pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16" + + lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule) + optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW) + ema_decay: float | None = 0.99 + + # Specifies which weights should be frozen. + freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing) + + # Determines the data to be trained on. + data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig) + + # Base directory for config assets (e.g., norm stats). + assets_base_dir: str = "./assets" + # Base directory for checkpoints. + checkpoint_base_dir: str = "./checkpoints" + + # Random seed that will be used by random generators during training. + seed: int = 42 + # Global batch size. + batch_size: int = 32 + # Number of workers to use for the data loader. Increasing this number will speed up data loading but + # will increase memory and CPU usage. + num_workers: int = 2 + # Number of train steps (batches) to run. + num_train_steps: int = 30_000 + + # How often (in steps) to log training metrics. + log_interval: int = 100 + # How often (in steps) to save checkpoints. + save_interval: int = 1000 + # If set, any existing checkpoints matching step % keep_period == 0 will not be deleted. + keep_period: int | None = 5000 + + # If true, will overwrite the checkpoint directory if it already exists. + overwrite: bool = False + # If true, will resume training from the last checkpoint. + resume: bool = False + + # If true, will enable wandb logging. + wandb_enabled: bool = True + + # Used to pass metadata to the policy server. + policy_metadata: dict[str, Any] | None = None + + # If the value is greater than 1, FSDP will be enabled and shard across number of specified devices; overall + # device memory will be reduced but training could potentially be slower. + # eg. if total device is 4 and fsdp devices is 2; then the model will shard to 2 devices and run + # data parallel between 2 groups of devices. + fsdp_devices: int = 1 + + @property + def assets_dirs(self) -> pathlib.Path: + """Get the assets directory for this config.""" + return (pathlib.Path(self.assets_base_dir) / self.name).resolve() + + @property + def checkpoint_dir(self) -> pathlib.Path: + """Get the checkpoint directory for this config.""" + if not self.exp_name: + raise ValueError("--exp_name must be set") + return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve() + + @property + def trainable_filter(self) -> nnx.filterlib.Filter: + """Get the filter for the trainable parameters.""" + return nnx.All(nnx.Param, nnx.Not(self.freeze_filter)) + + def __post_init__(self) -> None: + if self.resume and self.overwrite: + raise ValueError("Cannot resume and overwrite at the same time.") + + +# Use `get_config` if you need to get a config by name in your code. +_CONFIGS = [ + # + # Inference Aloha configs. 
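
As a usage note for the TrainConfig fields and properties defined above, the sketch below shows how a registered config is typically fetched and overridden in code; the config name and field overrides are illustrative, and `get_config` is defined at the bottom of this module.

import dataclasses

from openpi.training import config as _config

cfg = _config.get_config("pi0_aloha_sim")
run_cfg = dataclasses.replace(cfg, exp_name="my_experiment", batch_size=16)

print(run_cfg.checkpoint_dir)  # <checkpoint_base_dir>/pi0_aloha_sim/my_experiment, resolved to an absolute path
print(run_cfg.assets_dirs)     # <assets_base_dir>/pi0_aloha_sim, resolved to an absolute path

# __post_init__ runs again on every dataclasses.replace, so inconsistent flags are rejected early,
# e.g. dataclasses.replace(cfg, resume=True, overwrite=True) raises ValueError.
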
+ # + TrainConfig( + name="pi0_aloha", + model=pi0_config.Pi0Config(), + data=LeRobotAlohaDataConfig( + assets=AssetsConfig(asset_id="trossen"), + ), + policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, + ), + TrainConfig( + name="pi05_aloha", + model=pi0_config.Pi0Config(pi05=True), + data=LeRobotAlohaDataConfig( + assets=AssetsConfig(asset_id="trossen"), + ), + policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, + ), + TrainConfig( + name="pi0_aloha_towel", + model=pi0_config.Pi0Config(), + data=LeRobotAlohaDataConfig( + assets=AssetsConfig(asset_id="trossen"), + default_prompt="fold the towel", + ), + policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, + ), + TrainConfig( + name="pi0_aloha_tupperware", + model=pi0_config.Pi0Config(), + data=LeRobotAlohaDataConfig( + assets=AssetsConfig(asset_id="trossen"), + default_prompt="open the tupperware and put the food on the plate", + ), + policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, + ), + # + # Inference DROID configs. + # + TrainConfig( + name="pi0_droid", + model=pi0_config.Pi0Config(action_horizon=10), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + TrainConfig( + name="pi0_fast_droid", + model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + TrainConfig( + name="pi05_droid", + model=pi0_config.Pi0Config(action_horizon=15, pi05=True), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + # + # Fine-tuning Libero configs. + # + # These train configs define the hyperparameters for fine-tuning the base model on your own dataset. + # They are used to define key elements like the dataset you are training on, the base checkpoint you + # are using, and other hyperparameters like how many training steps to run or what learning rate to use. + # For your own dataset, you can copy this class and modify the dataset name, and data transforms based on + # the comments below. + TrainConfig( + # Change the name to reflect your model and dataset. + name="pi0_libero", + # Here you define the model config -- In this example we use pi0 as the model + # architecture and perform *full* finetuning. in the examples below we show how to modify + # this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture. + model=pi0_config.Pi0Config(), + # Here you define the dataset you are training on. In this example we use the Libero + # dataset. For your own dataset, you can change the repo_id to point to your dataset. + # Also modify the DataConfig to use the new config you made for your dataset above. + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig( + # This flag determines whether we load the prompt (i.e. the task instruction) from the + # ``task`` field in the LeRobot dataset. 
If set to True, the prompt will show up in + # a field called ``prompt`` in the input dict. The recommended setting is True. + prompt_from_task=True, + ), + extra_delta_transform=True, + ), + # Here you define which pre-trained checkpoint you want to load to initialize the model. + # This should match the model config you chose above -- i.e. in this case we use the pi0 base model. + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), + # Below you can define other hyperparameters like the learning rate, number of training steps, etc. + # Check the base TrainConfig class for a full list of available hyperparameters. + num_train_steps=30_000, + ), + TrainConfig( + name="pi0_libero_low_mem_finetune", + # Here is an example of loading a pi0 model for LoRA fine-tuning. + model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=True, + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), + num_train_steps=30_000, + # The freeze filter defines which parameters should be frozen during training. + # We have a convenience function in the model config that returns the default freeze filter + # for the given model config for LoRA finetuning. Just make sure it matches the model config + # you chose above. + freeze_filter=pi0_config.Pi0Config( + paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora" + ).get_freeze_filter(), + # Turn off EMA for LoRA finetuning. + ema_decay=None, + ), + TrainConfig( + name="pi0_fast_libero", + # Here is an example of loading a pi0-FAST model for full finetuning. + # Modify action_dim and action_horizon to match your dataset (action horizon is equal to + # the desired action chunk length). + # The max_token_len is the maximum number of (non-image) tokens the model can handle. + # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens. + # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw + # a warning), while choosing it too large will waste memory (since we pad each batch element to the + # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for + # two-arm robots. Generally, err on the lower side here first, and potentially increase the value if + # you see many warnings being thrown during training. + model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=True, + ), + # Note that we load the pi0-FAST base model checkpoint here. + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), + num_train_steps=30_000, + ), + TrainConfig( + name="pi0_fast_libero_low_mem_finetune", + # Here is an example of loading a pi0-FAST model for LoRA finetuning. + # For setting action_dim, action_horizon, and max_token_len, see the comments above. 
+ model=pi0_fast.Pi0FASTConfig( + action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" + ), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=True, + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), + num_train_steps=30_000, + # Again, make sure to match the model config above when extracting the freeze filter + # that specifies which parameters should be frozen during LoRA finetuning. + freeze_filter=pi0_fast.Pi0FASTConfig( + action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" + ).get_freeze_filter(), + # Turn off EMA for LoRA finetuning. + ema_decay=None, + ), + TrainConfig( + name="pi05_libero", + model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + batch_size=256, + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=10_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + ema_decay=0.999, + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), + pytorch_weight_path="/path/to/your/pytorch_weight_path", + num_train_steps=30_000, + ), + # + # Fine-tuning Aloha configs. + # + # Personal Tasks + TrainConfig( + name="pi05_capvector_aloha_place_block", # + model=pi0_config.Pi0Config(pi05=True, discrete_state_input=False), + data=LeRobotAlohaDataConfig( + repo_id="cobot_dataset/place_one_floor_block", # your datasets repo_id, like "/" + default_prompt="place the green block", + repack_transforms=_transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "images": { + "cam_high": "observation.images.cam_high", + "cam_left_wrist": "observation.images.cam_left_wrist", + "cam_right_wrist": "observation.images.cam_right_wrist", + }, + "state": "observation.state", + "actions": "action", + } + ) + ] + ), + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), + pytorch_weight_path='./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial', + # CapVector + regularization_vector_path='checkpoints/diff/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial.pth', + regularization_coeff=1e-4, + # + num_train_steps=30_000, + batch_size=16, + ema_decay=None, + wandb_enabled=False, + ), + # + # This is a test config that is used to illustate how train on a custom LeRobot dataset. 
+ # For instuctions on how to convert and train on your own Aloha dataset see examples/aloha_real/README.md + TrainConfig( + name="pi0_aloha_pen_uncap", + model=pi0_config.Pi0Config(), + data=LeRobotAlohaDataConfig( + repo_id="physical-intelligence/aloha_pen_uncap_diverse", + assets=AssetsConfig( + assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", + asset_id="trossen", + ), + default_prompt="uncap the pen", + repack_transforms=_transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "images": { + "cam_high": "observation.images.cam_high", + "cam_left_wrist": "observation.images.cam_left_wrist", + "cam_right_wrist": "observation.images.cam_right_wrist", + }, + "state": "observation.state", + "actions": "action", + } + ) + ] + ), + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), + num_train_steps=20_000, + ), + TrainConfig( + name="pi05_aloha_pen_uncap", + model=pi0_config.Pi0Config(pi05=True), + data=LeRobotAlohaDataConfig( + repo_id="physical-intelligence/aloha_pen_uncap_diverse", + assets=AssetsConfig( + assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets", + asset_id="trossen", + ), + default_prompt="uncap the pen", + repack_transforms=_transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "images": { + "cam_high": "observation.images.cam_high", + "cam_left_wrist": "observation.images.cam_left_wrist", + "cam_right_wrist": "observation.images.cam_right_wrist", + }, + "state": "observation.state", + "actions": "action", + } + ) + ] + ), + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), + num_train_steps=20_000, + batch_size=64, + ), + # + # Fine-tuning DROID configs. + # + TrainConfig( + # This config is for fine-tuning pi0-FAST-base on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. + # For fine-tuning on your own DROID dataset, see below. + name="pi0_fast_full_droid_finetune", + model=pi0_fast.Pi0FASTConfig( + action_dim=8, + action_horizon=16, + max_token_len=180, + ), + data=RLDSDroidDataConfig( + repo_id="droid", + # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). + rlds_data_dir="", + action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=1_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + num_train_steps=100_000, # 100k steps should be sufficient, takes ~2 days on 8x H100s + batch_size=256, + log_interval=100, + save_interval=5000, + keep_period=20_000, + num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally + ), + TrainConfig( + # This config is for fine-tuning pi05 on the *full* DROID dataset. + # We use RLDS data loading to make training on this large dataset tractable. + # For fine-tuning on your own DROID dataset, see below. + name="pi05_full_droid_finetune", + model=pi0_config.Pi0Config( + pi05=True, + action_dim=32, + action_horizon=16, + ), + data=RLDSDroidDataConfig( + repo_id="droid", + # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
+ rlds_data_dir="/mnt/pi-data/kevin", + action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, + assets=AssetsConfig( + assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/", + asset_id="droid", + ), + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=1_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + num_train_steps=100_000, + batch_size=256, + log_interval=100, + save_interval=5000, + keep_period=10_000, + num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally + ), + TrainConfig( + # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset. + # Here, we use LeRobot data format (like for all other fine-tuning examples) + # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py + name="pi05_droid_finetune", + model=pi0_config.Pi0Config( + pi05=True, + action_dim=32, # pi05 is trained with 32-dim actions + action_horizon=16, + ), + data=LeRobotDROIDDataConfig( + # Replace with your custom DROID LeRobot dataset repo id. + repo_id="your_hf_username/my_droid_dataset", + base_config=DataConfig(prompt_from_task=True), + assets=AssetsConfig( + # Important: reuse the original DROID norm stats during fine-tuning! + assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets", + asset_id="droid", + ), + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_droid/params"), + num_train_steps=20_000, + batch_size=32, + ), + # + # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment. + # + TrainConfig( + name="pi0_aloha_sim", + model=pi0_config.Pi0Config(), + data=LeRobotAlohaDataConfig( + repo_id="lerobot/aloha_sim_transfer_cube_human", + default_prompt="Transfer cube", + use_delta_joint_actions=False, + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), + num_train_steps=20_000, + ), + # + # Debugging configs. + # + TrainConfig( + name="debug", + data=FakeDataConfig(), + batch_size=2, + model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"), + save_interval=100, + overwrite=True, + exp_name="debug", + num_train_steps=10, + wandb_enabled=False, + ), + TrainConfig( + name="debug_restore", + data=FakeDataConfig(), + batch_size=2, + model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"), + weight_loader=weight_loaders.CheckpointWeightLoader("./checkpoints/debug/debug/9/params"), + overwrite=True, + exp_name="debug", + num_train_steps=10, + wandb_enabled=False, + ), + TrainConfig( + name="debug_pi05", + model=pi0_config.Pi0Config(pi05=True, paligemma_variant="dummy", action_expert_variant="dummy"), + data=FakeDataConfig(), + batch_size=2, + num_train_steps=10, + overwrite=True, + exp_name="debug_pi05", + wandb_enabled=False, + ), + # + # RoboArena configs. 
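
The debug configs above use FakeDataConfig together with dummy model variants, so the full data pipeline can be smoke-tested without downloading any dataset or checkpoint. A minimal sketch of that workflow, as an assumed usage pattern rather than part of this module:

from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

config = _config.get_config("debug")
loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=1)

observation, actions = next(iter(loader))
# actions has shape (batch_size, action_horizon, action_dim) for the dummy pi0 model config.
print(actions.shape)
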
+ # + *roboarena_config.get_roboarena_configs(), +] + +if len({config.name for config in _CONFIGS}) != len(_CONFIGS): + raise ValueError("Config names must be unique.") +_CONFIGS_DICT = {config.name: config for config in _CONFIGS} + + +def cli() -> TrainConfig: + return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()}) + + +def get_config(config_name: str) -> TrainConfig: + """Get a config by name.""" + if config_name not in _CONFIGS_DICT: + closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0) + closest_str = f" Did you mean '{closest[0]}'? " if closest else "" + raise ValueError(f"Config '{config_name}' not found.{closest_str}") + + return _CONFIGS_DICT[config_name] diff --git a/capvector-pi05/src/openpi/training/data_loader.py b/capvector-pi05/src/openpi/training/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..1847371abda186a0ea9de50e9f86b504e90a2cc2 --- /dev/null +++ b/capvector-pi05/src/openpi/training/data_loader.py @@ -0,0 +1,540 @@ +from collections.abc import Iterator, Sequence +import logging +import multiprocessing +import os +import typing +from typing import Literal, Protocol, SupportsIndex, TypeVar + +import jax +import jax.numpy as jnp +import lerobot.common.datasets.lerobot_dataset as lerobot_dataset +import numpy as np +import torch + +import openpi.models.model as _model +import openpi.training.config as _config +from openpi.training.droid_rlds_dataset import DroidRldsDataset +import openpi.transforms as _transforms + +T_co = TypeVar("T_co", covariant=True) + + +class Dataset(Protocol[T_co]): + """Interface for a dataset with random access.""" + + def __getitem__(self, index: SupportsIndex) -> T_co: + raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") + + def __len__(self) -> int: + raise NotImplementedError("Subclasses of Dataset should implement __len__.") + + +class IterableDataset(Protocol[T_co]): + """Interface for an iterable dataset.""" + + def __iter__(self) -> Iterator[T_co]: + raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.") + + def __len__(self) -> int: + raise NotImplementedError("Subclasses of Dataset should implement __len__.") + + +class DataLoader(Protocol[T_co]): + """Interface for a data loader.""" + + def data_config(self) -> _config.DataConfig: + """Get the data config for this data loader.""" + raise NotImplementedError("Subclasses of DataLoader should implement data_config.") + + def __iter__(self) -> Iterator[T_co]: + raise NotImplementedError("Subclasses of DataLoader should implement __iter__.") + + +class TransformedDataset(Dataset[T_co]): + def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]): + self._dataset = dataset + self._transform = _transforms.compose(transforms) + + def __getitem__(self, index: SupportsIndex) -> T_co: + return self._transform(self._dataset[index]) + + def __len__(self) -> int: + return len(self._dataset) + + +class IterableTransformedDataset(IterableDataset[T_co]): + def __init__( + self, + dataset: IterableDataset, + transforms: Sequence[_transforms.DataTransformFn], + *, + is_batched: bool = False, + ): + self._dataset = dataset + self._transform = _transforms.compose(transforms) + self._is_batched = is_batched + + def __iter__(self): + for sample in self._dataset: + if self._is_batched: + # Transforms are designed to be applied to individual samples. 
So we need to split the batch into + # individual samples and apply the transform to each sample individually. + batch_size = next(v.shape[0] for v in sample.values()) + + # Split batch into individual samples using tree_map + individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)] # noqa: B023 + + # Transform each sample + transformed = [self._transform(s) for s in individual_samples] + + # Recombine batch with tree_map + yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed) + else: + yield self._transform(sample) + + def __len__(self) -> int: + return len(self._dataset) + + +class FakeDataset(Dataset): + def __init__(self, model_config: _model.BaseModelConfig, num_samples: int): + self._num_samples = num_samples + self._observation_spec, self._action_spec = model_config.inputs_spec() + + def __getitem__(self, index: SupportsIndex) -> dict: + rng = jax.random.key(index.__index__()) + + def make_from_spec(spec: jax.ShapeDtypeStruct): + nonlocal rng + rng, data_rng = jax.random.split(rng) + # Remove the batch dimension. + shape = spec.shape[1:] + if spec.dtype == jnp.float32: + return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0) + if spec.dtype == jnp.int32: + return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048) + return jnp.zeros(shape=shape, dtype=spec.dtype) + + observation = jax.tree.map(make_from_spec, self._observation_spec) + action = jax.tree.map(make_from_spec, self._action_spec) + + return { + **observation.to_dict(), + "actions": action, + } + + def __len__(self) -> int: + return self._num_samples + + +def create_torch_dataset( + data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig +) -> Dataset: + """Create a dataset for training.""" + repo_id = data_config.repo_id + if repo_id is None: + raise ValueError("Repo ID is not set. Cannot create dataset.") + if repo_id == "fake": + return FakeDataset(model_config, num_samples=1024) + + dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id) + dataset = lerobot_dataset.LeRobotDataset( + data_config.repo_id, + delta_timestamps={ + key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys + }, + ) + + if data_config.prompt_from_task: + dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)]) + + return dataset + + +def create_rlds_dataset( + data_config: _config.DataConfig, + action_horizon: int, + batch_size: int, + *, + shuffle: bool = False, +) -> Dataset: + # At the moment, we only support DROID for RLDS datasets. + return DroidRldsDataset( + data_dir=data_config.rlds_data_dir, + batch_size=batch_size, + shuffle=shuffle, + action_chunk_size=action_horizon, + action_space=data_config.action_space, + filter_dict_path=data_config.filter_dict_path, + ) + + +def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset: + """Transform the dataset by applying the data transforms.""" + norm_stats = {} + if data_config.repo_id != "fake" and not skip_norm_stats: + if data_config.norm_stats is None: + raise ValueError( + "Normalization stats not found. " + "Make sure to run `scripts/compute_norm_stats.py --config-name=`." 
+ ) + norm_stats = data_config.norm_stats + + return TransformedDataset( + dataset, + [ + *data_config.repack_transforms.inputs, + *data_config.data_transforms.inputs, + _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), + *data_config.model_transforms.inputs, + ], + ) + + +def transform_iterable_dataset( + dataset: IterableDataset, + data_config: _config.DataConfig, + *, + skip_norm_stats: bool = False, + is_batched: bool = False, +) -> IterableDataset: + """Transform the dataset by applying the data transforms.""" + norm_stats = {} + if data_config.repo_id != "fake" and not skip_norm_stats: + if data_config.norm_stats is None: + raise ValueError( + "Normalization stats not found. " + "Make sure to run `scripts/compute_norm_stats.py --config-name=`." + ) + norm_stats = data_config.norm_stats + + return IterableTransformedDataset( + dataset, + [ + *data_config.repack_transforms.inputs, + *data_config.data_transforms.inputs, + _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), + *data_config.model_transforms.inputs, + ], + is_batched=is_batched, + ) + + +def create_data_loader( + config: _config.TrainConfig, + *, + sharding: jax.sharding.Sharding | None = None, + shuffle: bool = False, + num_batches: int | None = None, + skip_norm_stats: bool = False, + framework: Literal["jax", "pytorch"] = "jax", +) -> DataLoader[tuple[_model.Observation, _model.Actions]]: + """Create a data loader for training. + + Args: + config: The training configuration. + sharding: The sharding to use for the data loader (JAX only). + shuffle: Whether to shuffle the data. + num_batches: Determines the number of batches to return. + skip_norm_stats: Whether to skip data normalization. + framework: The framework to use ("jax" or "pytorch"). + """ + data_config = config.data.create(config.assets_dirs, config.model) + logging.info(f"data_config: {data_config}") + + if data_config.rlds_data_dir is not None: + return create_rlds_data_loader( + data_config, + action_horizon=config.model.action_horizon, + batch_size=config.batch_size, + sharding=sharding, + shuffle=shuffle, + num_batches=num_batches, + skip_norm_stats=skip_norm_stats, + framework=framework, + ) + return create_torch_data_loader( + data_config, + model_config=config.model, + action_horizon=config.model.action_horizon, + batch_size=config.batch_size, + sharding=sharding, + shuffle=shuffle, + num_batches=num_batches, + num_workers=config.num_workers, + seed=config.seed, + skip_norm_stats=skip_norm_stats, + framework=framework, + ) + + +def create_torch_data_loader( + data_config: _config.DataConfig, + model_config: _model.BaseModelConfig, + action_horizon: int, + batch_size: int, + *, + sharding: jax.sharding.Sharding | None = None, + skip_norm_stats: bool = False, + shuffle: bool = False, + num_batches: int | None = None, + num_workers: int = 0, + seed: int = 0, + framework: str = "jax", +) -> DataLoader[tuple[_model.Observation, _model.Actions]]: + """Create a data loader for training. + + Args: + data_config: The data configuration. + action_horizon: The action horizon. + batch_size: The batch size. + sharding: The sharding to use for the data loader. If None, the data loader will + use a single device sharding. + skip_norm_stats: Whether to skip data normalization. + shuffle: Whether to shuffle the data. + num_batches: Determines the number of batches to return. If the number exceeds the + number of batches in the dataset, the data loader will loop over the dataset. 
+ If not provided, will iterate over the dataset indefinitely. + num_workers: The number of worker processes to use. If zero, the data loader will + execute in the main process. + seed: The seed to use for shuffling the data. + """ + dataset = create_torch_dataset(data_config, action_horizon, model_config) + dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats) + + # Use TorchDataLoader for both frameworks + # For PyTorch DDP, create DistributedSampler and divide batch size by world size + # For JAX, divide by process count + sampler = None + if framework == "pytorch": + if torch.distributed.is_initialized(): + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=torch.distributed.get_world_size(), + rank=torch.distributed.get_rank(), + shuffle=shuffle, + drop_last=True, + ) + local_batch_size = batch_size // torch.distributed.get_world_size() + else: + local_batch_size = batch_size + else: + local_batch_size = batch_size // jax.process_count() + + logging.info(f"local_batch_size: {local_batch_size}") + data_loader = TorchDataLoader( + dataset, + local_batch_size=local_batch_size, + sharding=None if framework == "pytorch" else sharding, + shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler + sampler=sampler, + num_batches=num_batches, + num_workers=num_workers, + seed=seed, + framework=framework, + ) + + return DataLoaderImpl(data_config, data_loader) + + +def create_rlds_data_loader( + data_config: _config.DataConfig, + action_horizon: int, + batch_size: int, + *, + sharding: jax.sharding.Sharding | None = None, + skip_norm_stats: bool = False, + shuffle: bool = False, + num_batches: int | None = None, + framework: str = "jax", +) -> DataLoader[tuple[_model.Observation, _model.Actions]]: + """Create an RLDS data loader for training. + + Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md + + Args: + data_config: The data configuration. + action_horizon: The action horizon. + batch_size: The batch size. + sharding: The sharding to use for the data loader. If None, the data loader will + use a single device sharding. + skip_norm_stats: Whether to skip data normalization. + shuffle: Whether to shuffle the data. + num_batches: Determines the number of batches to return. If the number exceeds the + number of batches in the dataset, the data loader will loop over the dataset. + If not provided, will iterate over the dataset indefinitely. + """ + if framework == "pytorch": + raise NotImplementedError("PyTorch RLDS data loader is not supported yet") + dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle) + dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True) + + data_loader = RLDSDataLoader( + dataset, + sharding=sharding, + num_batches=num_batches, + ) + + return DataLoaderImpl(data_config, data_loader) + + +class TorchDataLoader: + """Torch data loader implementation.""" + + def __init__( + self, + dataset, + local_batch_size: int, + *, + sharding: jax.sharding.Sharding | None = None, + shuffle: bool = False, + sampler: torch.utils.data.Sampler | None = None, + num_batches: int | None = None, + num_workers: int = 0, + seed: int = 0, + framework: str = "jax", + ): + """Create a PyTorch data loader. + + Args: + dataset: The dataset to load. + local_batch_size: The local batch size for each process. + sharding: The sharding to use for the data loader. 
+ shuffle: Whether to shuffle the data. + num_batches: If provided, determines the number of returned batches. If the + number is larger than the number of batches in the dataset, the data loader + will loop over the dataset. If not provided, will iterate over the dataset + indefinitely. + num_workers: The number of worker processes to use. If zero, the data loader will + execute in the main process. + seed: The seed to use for shuffling the data. + """ + if jax.process_count() > 1: + raise NotImplementedError("Data loading with multiple processes is not supported.") + + if len(dataset) < local_batch_size: + raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).") + + # Store sharding - None for PyTorch, JAX sharding for JAX + self._sharding = sharding + if sharding is None and framework == "jax": + # Use data parallel sharding by default for JAX only. + self._sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), ("B",)), + jax.sharding.PartitionSpec("B"), + ) + self._num_batches = num_batches + + mp_context = None + if num_workers > 0: + mp_context = multiprocessing.get_context("spawn") + + generator = torch.Generator() + generator.manual_seed(seed) + self._data_loader = torch.utils.data.DataLoader( + typing.cast(torch.utils.data.Dataset, dataset), + batch_size=local_batch_size, + shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler + sampler=sampler, + num_workers=num_workers, + multiprocessing_context=mp_context, + persistent_workers=num_workers > 0, + collate_fn=_collate_fn, + worker_init_fn=_worker_init_fn, + drop_last=True, + generator=generator, + ) + + @property + def torch_loader(self) -> torch.utils.data.DataLoader: + return self._data_loader + + def __iter__(self): + num_items = 0 + while True: + data_iter = iter(self._data_loader) + while True: + if self._num_batches is not None and num_items >= self._num_batches: + return + try: + batch = next(data_iter) + except StopIteration: + break # We've exhausted the dataset. Create a new iterator and start over. + num_items += 1 + # For JAX, convert to sharded arrays; for PyTorch, return torch tensors + if self._sharding is not None: + yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) + else: + yield jax.tree.map(torch.as_tensor, batch) + + +def _collate_fn(items): + """Collate the batch elements into batched numpy arrays.""" + # Make sure to convert to numpy arrays before stacking since some of the incoming elements + # may be JAX arrays. + return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) + + +def _worker_init_fn(worker_id: int) -> None: + """Tell JAX inside the worker process not to preallocate the GPU memory.""" + # NOTE: This is called after jax is imported inside the worker process. This + # means that this approach will not work for selecting the backend. + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + + +class RLDSDataLoader: + """Shallow wrapper around the DROID data loader to make it compatible with openpi. + + All batching already happens in the DROID dataset, so we don't need to do anything here. 
+ """ + + def __init__( + self, + dataset: DroidRldsDataset, + *, + sharding: jax.sharding.Sharding | None = None, + num_batches: int | None = None, + ): + self._dataset = dataset + self._num_batches = num_batches + + if jax.process_count() > 1: + raise NotImplementedError("Data loading with multiple processes is not supported.") + + if sharding is None: + # Use data parallel sharding by default. + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), ("B",)), + jax.sharding.PartitionSpec("B"), + ) + + self._sharding = sharding + self._num_batches = num_batches + + def __iter__(self): + num_items = 0 + while True: + data_iter = iter(self._dataset) + while True: + if self._num_batches is not None and num_items >= self._num_batches: + return + try: + batch = next(data_iter) + except StopIteration: + break # We've exhausted the dataset. Create a new iterator and start over. + num_items += 1 + yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) + + +class DataLoaderImpl(DataLoader): + def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader): + self._data_config = data_config + self._data_loader = data_loader + + def data_config(self) -> _config.DataConfig: + return self._data_config + + def __iter__(self): + for batch in self._data_loader: + yield _model.Observation.from_dict(batch), batch["actions"] diff --git a/capvector-pi05/src/openpi/training/data_loader_test.py b/capvector-pi05/src/openpi/training/data_loader_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3b77188885a14c4c82a23f288d912aab03028dcb --- /dev/null +++ b/capvector-pi05/src/openpi/training/data_loader_test.py @@ -0,0 +1,84 @@ +import dataclasses + +import jax + +from openpi.models import pi0_config +from openpi.training import config as _config +from openpi.training import data_loader as _data_loader + + +def test_torch_data_loader(): + config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) + dataset = _data_loader.FakeDataset(config, 16) + + loader = _data_loader.TorchDataLoader( + dataset, + local_batch_size=4, + num_batches=2, + ) + batches = list(loader) + + assert len(batches) == 2 + for batch in batches: + assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) + + +def test_torch_data_loader_infinite(): + config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) + dataset = _data_loader.FakeDataset(config, 4) + + loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4) + data_iter = iter(loader) + + for _ in range(10): + _ = next(data_iter) + + +def test_torch_data_loader_parallel(): + config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) + dataset = _data_loader.FakeDataset(config, 10) + + loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2) + batches = list(loader) + + assert len(batches) == 2 + + for batch in batches: + assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) + + +def test_with_fake_dataset(): + config = _config.get_config("debug") + + loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2) + batches = list(loader) + + assert len(batches) == 2 + + for batch in batches: + assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch)) + + for _, actions in batches: + assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) + + +def 
test_with_real_dataset(): + config = _config.get_config("pi0_aloha_sim") + config = dataclasses.replace(config, batch_size=4) + + loader = _data_loader.create_data_loader( + config, + # Skip since we may not have the data available. + skip_norm_stats=True, + num_batches=2, + shuffle=True, + ) + # Make sure that we can get the data config. + assert loader.data_config().repo_id == config.data.repo_id + + batches = list(loader) + + assert len(batches) == 2 + + for _, actions in batches: + assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) diff --git a/capvector-pi05/src/openpi/training/droid_rlds_dataset.py b/capvector-pi05/src/openpi/training/droid_rlds_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..debbe73cde9718b3d17157799b676fd65b00c7ba --- /dev/null +++ b/capvector-pi05/src/openpi/training/droid_rlds_dataset.py @@ -0,0 +1,221 @@ +""" +RLDS-based data loader for DROID. +While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID. +Thus, we provide a data loader example here that uses the RLDS data format. +The data loader also applies a few DROID-specific data filters / transformations. +""" + +from enum import Enum +from enum import auto +import json +import logging +from pathlib import Path + +import tqdm + +import openpi.shared.download as download + + +class DroidActionSpace(Enum): + """Action space for DROID dataset.""" + + JOINT_POSITION = auto() + JOINT_VELOCITY = auto() + + +class DroidRldsDataset: + def __init__( + self, + data_dir: str, + batch_size: int, + *, # Force keyword-only arguments + shuffle: bool = True, + action_chunk_size: int = 16, + # We default to joint position actions, since they allow policy evaluation in simulation. + action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION, + max_loaded_steps_per_episode: int = 100, + # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random. + shuffle_buffer_size: int = 250_000, + num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level + num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level + filter_dict_path=None, # Path to json file with indices to sample during training + ): + # Import tensorflow here to not make it mandatory in case RLDS data loader is not used. + import dlimp as dl + import tensorflow as tf + import tensorflow_datasets as tfds + + # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX) + tf.config.set_visible_devices([], "GPU") + + builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1") + dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads) + + # Filter out any unsuccessful trajectories -- we use the file name to check this + dataset = dataset.filter( + lambda traj: tf.strings.regex_full_match( + traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*" + ) + ) + + # # Repeat dataset so we never run out of data. + dataset = dataset.repeat() + + # Load the filter dictionary if provided. + # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample + # (e.g., + # { + # "": [[0, 100], [200, 300]] + # } + # means keep frames 0-99 and 200-299). 
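
To make the filter-dictionary format concrete, here is a small pure-Python sketch of the lookup that the TensorFlow hash table below implements; the episode key and ranges are illustrative, and per the comment in RLDSDroidDataConfig an episode key has the form f"{recording_folderpath}--{file_path}".

filter_dict = {
    # Illustrative episode key of the form f"{recording_folderpath}--{file_path}".
    "lab1/recordings/ep_0001--lab1/trajectories/ep_0001.h5": [[0, 100], [200, 300]],
}

# Expand the ranges into per-frame keys, matching how keys_tensor is built below.
keep = {
    f"{episode_key}--{t}"
    for episode_key, ranges in filter_dict.items()
    for start, end in ranges
    for t in range(start, end)
}

assert "lab1/recordings/ep_0001--lab1/trajectories/ep_0001.h5--50" in keep       # frame 50 is kept
assert "lab1/recordings/ep_0001--lab1/trajectories/ep_0001.h5--150" not in keep  # frame 150 is filtered out
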
+ if filter_dict_path is not None: + cached_filter_dict_path = download.maybe_download(filter_dict_path) + with Path(cached_filter_dict_path).open("r") as f: + filter_dict = json.load(f) + + logging.info(f"Using filter dictionary with {len(filter_dict)} episodes") + + keys_tensor = [] + values_tensor = [] + + for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."): + for start, end in ranges: + for t in range(start, end): + frame_key = f"{episode_key}--{t}" + keys_tensor.append(frame_key) + values_tensor.append(True) + self.filter_table = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False + ) + logging.info("Filter hash table initialized") + else: + self.filter_table = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True + ) + + def restructure(traj): + """Reformat observation and action keys, sample language instruction.""" + # Important: we use joint *position* action space -- easier to simulate! + actions = tf.concat( + ( + ( + traj["action_dict"]["joint_position"] + if action_space == DroidActionSpace.JOINT_POSITION + else traj["action_dict"]["joint_velocity"] + ), + traj["action_dict"]["gripper_position"], + ), + axis=-1, + ) + # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time). + # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera. + exterior_img = tf.cond( + tf.random.uniform(shape=[]) > 0.5, + lambda: traj["observation"]["exterior_image_1_left"], + lambda: traj["observation"]["exterior_image_2_left"], + ) + wrist_img = traj["observation"]["wrist_image_left"] + # Randomly sample one of the three language instructions + instruction = tf.random.shuffle( + [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]] + )[0] + + traj_len = tf.shape(traj["action"])[0] + indices = tf.as_string(tf.range(traj_len)) + + # Data filtering: + # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path, + # and each step's time step index. This will index into the filter hash table, and if it returns true, + # then the frame passes the filter. 
+ step_id = ( + traj["traj_metadata"]["episode_metadata"]["recording_folderpath"] + + "--" + + traj["traj_metadata"]["episode_metadata"]["file_path"] + + "--" + + indices + ) + passes_filter = self.filter_table.lookup(step_id) + + return { + "actions": actions, + "observation": { + "image": exterior_img, + "wrist_image": wrist_img, + "joint_position": traj["observation"]["joint_position"], + "gripper_position": traj["observation"]["gripper_position"], + }, + "prompt": instruction, + "step_id": step_id, + "passes_filter": passes_filter, + } + + dataset = dataset.traj_map(restructure, num_parallel_calls) + + def chunk_actions(traj): + """Splits episode into action chunks.""" + traj_len = tf.shape(traj["actions"])[0] + + # For each step in the trajectory, construct indices for the next n actions + action_chunk_indices = tf.broadcast_to( + tf.range(action_chunk_size)[None], + [traj_len, action_chunk_size], + ) + tf.broadcast_to( + tf.range(traj_len)[:, None], + [traj_len, action_chunk_size], + ) + + # Cap to length of the sequence --> final chunks will repeat the last action + # This makes sense, since we are using absolute joint + gripper position actions + action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1) + + # Gather the actions for each chunk + traj["actions"] = tf.gather(traj["actions"], action_chunk_indices) + return traj + + dataset = dataset.traj_map(chunk_actions, num_parallel_calls) + + # Flatten: map from trajectory dataset to dataset of individual action chunks + dataset = dataset.flatten(num_parallel_calls=num_parallel_calls) + + # Filter data that doesn't pass the filter + def filter_from_dict(frame): + return frame["passes_filter"] + + dataset = dataset.filter(filter_from_dict) + + # Remove "passes_filter" key from output + def remove_passes_filter(frame): + frame.pop("passes_filter") + return frame + + dataset = dataset.map(remove_passes_filter) + + # Decode images: RLDS saves encoded images, only decode now for efficiency + def decode_images(traj): + traj["observation"]["image"] = tf.io.decode_image( + traj["observation"]["image"], expand_animations=False, dtype=tf.uint8 + ) + traj["observation"]["wrist_image"] = tf.io.decode_image( + traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8 + ) + return traj + + dataset = dataset.frame_map(decode_images, num_parallel_calls) + + # Shuffle, batch + dataset = dataset.shuffle(shuffle_buffer_size) + dataset = dataset.batch(batch_size) + # Note =>> Seems to reduce memory usage without affecting speed? + dataset = dataset.with_ram_budget(1) + + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + + def __iter__(self): + yield from self.dataset.as_numpy_iterator() + + def __len__(self): + # This is the approximate number of samples in DROID after filtering. + # Easier to hardcode than to iterate through the dataset and compute it. 
+ return 20_000_000 diff --git a/capvector-pi05/src/openpi/training/misc/roboarena_config.py b/capvector-pi05/src/openpi/training/misc/roboarena_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e0f366a43caf004d0f291db5af1d1678083888ca --- /dev/null +++ b/capvector-pi05/src/openpi/training/misc/roboarena_config.py @@ -0,0 +1,116 @@ +"""RoboArena baseline policy configs.""" + +from typing import TypeAlias + +import openpi.models.model as _model +import openpi.models.pi0_config as pi0_config +import openpi.models.pi0_fast as pi0_fast +import openpi.models.tokenizer as _tokenizer +import openpi.policies.droid_policy as droid_policy +import openpi.transforms as _transforms + +ModelType: TypeAlias = _model.ModelType + + +def get_roboarena_configs(): + # Import here to avoid circular imports. + from openpi.training.config import AssetsConfig + from openpi.training.config import DataConfig + from openpi.training.config import SimpleDataConfig + from openpi.training.config import TrainConfig + + return [ + # + # RoboArena DROID baseline inference configs. + # + TrainConfig( + # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer. + name="paligemma_binning_droid", + model=pi0_fast.Pi0FASTConfig( + action_dim=8, + action_horizon=15, + max_token_len=400, + fast_model_tokenizer=_tokenizer.BinningTokenizer, + ), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + TrainConfig( + # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer). + name="paligemma_fast_droid", + model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + TrainConfig( + # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset). + name="paligemma_fast_specialist_droid", + model=pi0_fast.Pi0FASTConfig( + action_dim=8, + action_horizon=15, + fast_model_tokenizer=_tokenizer.FASTTokenizer, + fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"}, + ), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + TrainConfig( + # Trained from PaliGemma, using FSQ tokenizer. 
+ name="paligemma_vq_droid", + model=pi0_fast.Pi0FASTConfig( + action_dim=8, + action_horizon=15, + fast_model_tokenizer=_tokenizer.FSQTokenizer, + fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"}, + ), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + TrainConfig( + # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma. + name="paligemma_diffusion_droid", + model=pi0_config.Pi0Config(action_horizon=10, action_dim=8), + data=SimpleDataConfig( + assets=AssetsConfig(asset_id="droid"), + data_transforms=lambda model: _transforms.Group( + inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)], + outputs=[droid_policy.DroidOutputs()], + ), + base_config=DataConfig( + prompt_from_task=True, + ), + ), + ), + ] diff --git a/capvector-pi05/src/openpi/training/optimizer.py b/capvector-pi05/src/openpi/training/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a233bfd0e2d0fc62c295bd3ab82b35726a5fc545 --- /dev/null +++ b/capvector-pi05/src/openpi/training/optimizer.py @@ -0,0 +1,109 @@ +import dataclasses +from typing import Protocol, runtime_checkable + +import jax.numpy as jnp +import optax + +import openpi.shared.array_typing as at + + +@runtime_checkable +class LRScheduleConfig(Protocol): + def create(self) -> optax.Schedule: ... + + +@dataclasses.dataclass(frozen=True) +class CosineDecaySchedule(LRScheduleConfig): + """Cosine decay schedule with warmup.""" + + warmup_steps: int = 1_000 + peak_lr: float = 2.5e-5 + decay_steps: int = 30_000 + decay_lr: float = 2.5e-6 + + def create(self) -> optax.Schedule: + return optax.warmup_cosine_decay_schedule( + init_value=self.peak_lr / (self.warmup_steps + 1), + peak_value=self.peak_lr, + warmup_steps=self.warmup_steps, + decay_steps=self.decay_steps, + end_value=self.decay_lr, + ) + + +@dataclasses.dataclass(frozen=True) +class RsqrtDecaySchedule(LRScheduleConfig): + """Inverse square root decay schedule with warmup.""" + + warmup_steps: int = 1_000 + peak_lr: float = 5e-5 + timescale: float = 10_000 + + def create(self) -> optax.Schedule: + return optax.join_schedules( + [ + optax.linear_schedule( + init_value=self.peak_lr / (self.warmup_steps + 1), + end_value=self.peak_lr, + transition_steps=self.warmup_steps, + ), + lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale), + ], + [self.warmup_steps], + ) + + +@runtime_checkable +class OptimizerConfig(Protocol): + def create( + self, + lr: optax.ScalarOrSchedule, + weight_decay_mask: at.PyTree | None = None, + ) -> optax.GradientTransformation: ... + + +@dataclasses.dataclass(frozen=True) +class AdamW(OptimizerConfig): + """AdamW optimizer.""" + + b1: float = 0.9 + b2: float = 0.95 + eps: float = 1e-8 + # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value. 
+ weight_decay: float = 1e-10 + clip_gradient_norm: float = 1.0 + + def create( + self, + lr: optax.ScalarOrSchedule, + weight_decay_mask: at.PyTree | None = None, + ) -> optax.GradientTransformation: + tx = optax.adamw( + lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask + ) + + return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx) + + +@dataclasses.dataclass(frozen=True) +class SGD(OptimizerConfig): + """SGD optimizer.""" + + lr: float = 5e-5 + momentum: float = 0.9 + nesterov: bool = False + + def create( + self, + lr: optax.ScalarOrSchedule, + weight_decay_mask: at.PyTree | None = None, + ) -> optax.GradientTransformation: + assert weight_decay_mask is None, "Weight decay is not supported for SGD" + return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov) + + +def create_optimizer( + optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None +) -> optax.GradientTransformation: + lr = lr_schedule.create() + return optimizer.create(lr, weight_decay_mask=weight_decay_mask) diff --git a/capvector-pi05/src/openpi/training/sharding.py b/capvector-pi05/src/openpi/training/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..6b34e5e11069637028f1792a14fd8be95073dfcd --- /dev/null +++ b/capvector-pi05/src/openpi/training/sharding.py @@ -0,0 +1,102 @@ +import contextlib +import logging + +import jax +import numpy as np + +BATCH_AXIS = "batch" +FSDP_AXIS = "fsdp" +# In FSDP, we shard the data across both the batch and FSDP axes. +DATA_AXIS = (BATCH_AXIS, FSDP_AXIS) + + +class _MeshState: + active_mesh: jax.sharding.Mesh | None = None + + +def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh: + if jax.device_count() % num_fsdp_devices != 0: + raise ValueError( + f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}." + ) + mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices) + return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS)) + + +@contextlib.contextmanager +def set_mesh(mesh: jax.sharding.Mesh): + """Plumbing the mesh deep into the module tree is extremeley cumbersome; until the JAX team lands a better API, a + custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used + in `activation_sharding_constraint` below.""" + if _MeshState.active_mesh is not None: + raise ValueError("Cannot nest set_mesh context managers.") + _MeshState.active_mesh = mesh + try: + yield + finally: + _MeshState.active_mesh = None + + +def activation_sharding_constraint(pytree): + if _MeshState.active_mesh is None: + return pytree + return jax.lax.with_sharding_constraint( + pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS)) + ) + + +def fsdp_sharding( + pytree, + mesh: jax.sharding.Mesh, + *, + min_size_mbytes: int = 4, # 4 MiB + log: bool = False, +): + """Apply FSDP sharding to a pytree of arrays based on the mesh shape. + + Args: + pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr) + will be considered for sharding. + mesh: The mesh being used for applying sharding on to pytree. + min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this + will be replicated. + log: If true, will log the sharding decisions for arrays that are being considered for sharding. 
+ + Returns: + The sharded pytree. + """ + min_size_bytes = min_size_mbytes * 2**20 + + def _shard_arr(kp, array: jax.ShapeDtypeStruct): + # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging + if mesh.shape[FSDP_AXIS] == 1: + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + # replicate scalar and vector arrays + if not hasattr(array, "shape"): + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + if len(array.shape) < 2: + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + # replicate small arrays + if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes: + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension + axes = np.argsort(array.shape)[::-1] + spec = [None] * len(axes) + for i in axes: + if array.shape[i] % mesh.shape[FSDP_AXIS] == 0: + if log: + logging.info( + f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}" + ) + spec[i] = FSDP_AXIS + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec)) + + # replicate if no valid sharding was found + if log: + logging.warning( + f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}" + ) + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + return jax.tree_util.tree_map_with_path(_shard_arr, pytree) diff --git a/capvector-pi05/src/openpi/training/utils.py b/capvector-pi05/src/openpi/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5593fee824510233d18c527b7a4f16470970fca3 --- /dev/null +++ b/capvector-pi05/src/openpi/training/utils.py @@ -0,0 +1,38 @@ +from collections.abc import Callable +from typing import Any + +from flax import nnx +from flax import struct +import jax +import optax + +from openpi.models import model as _model +from openpi.shared import array_typing as at + + +@at.typecheck +@struct.dataclass +class TrainState: + step: at.Int[at.ArrayLike, ""] + params: nnx.State + model_def: nnx.GraphDef[_model.BaseModel] + opt_state: optax.OptState + tx: optax.GradientTransformation = struct.field(pytree_node=False) + + ema_decay: float | None = struct.field(pytree_node=False) + ema_params: nnx.State | None = None + + +@at.typecheck +def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str: + """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert + the leaf values to more meaningful strings. 
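+
+    For example, a tree like {"opt": {"count": 3}} is rendered one leaf per line as
+    "['opt']['count']: 3".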
+ """ + tree, _ = jax.tree_util.tree_flatten_with_path(tree) + return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree) + + +@at.typecheck +def array_tree_to_info(tree: at.PyTree) -> str: + """Converts a PyTree of arrays into a human-readable string for logging.""" + return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}") diff --git a/capvector-pi05/src/openpi/training/weight_loaders.py b/capvector-pi05/src/openpi/training/weight_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..f13f3cb90475a59e76c81392c0ee8246192e8d24 --- /dev/null +++ b/capvector-pi05/src/openpi/training/weight_loaders.py @@ -0,0 +1,104 @@ +import dataclasses +import logging +import re +from typing import Protocol, runtime_checkable + +import flax.traverse_util +import numpy as np + +import openpi.models.model as _model +import openpi.shared.array_typing as at +import openpi.shared.download as download + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class WeightLoader(Protocol): + def load(self, params: at.Params) -> at.Params: + """Loads the model weights. + + Args: + params: Parameters of the model. This is a nested structure of array-like objects that + represent the model's parameters. + + Returns: + Loaded parameters. The structure must be identical to `params`. If returning a subset of + the parameters the loader must merge the loaded parameters with `params`. + """ + + +@dataclasses.dataclass(frozen=True) +class NoOpWeightLoader(WeightLoader): + def load(self, params: at.Params) -> at.Params: + return params + + +@dataclasses.dataclass(frozen=True) +class CheckpointWeightLoader(WeightLoader): + """Loads an entire set of weights from a checkpoint. + + Compatible with: + trained checkpoints: + example: "./checkpoints////params" + released checkpoints: + example: "gs://openpi-assets/checkpoints//params" + """ + + params_path: str + + def load(self, params: at.Params) -> at.Params: + # We are loading np.ndarray and relying on the training code to properly convert and shard the params. + loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray) + # Add all missing LoRA weights. + return _merge_params(loaded_params, params, missing_regex=".*lora.*") + + +@dataclasses.dataclass(frozen=True) +class PaliGemmaWeightLoader(WeightLoader): + """Loads weights from the official PaliGemma checkpoint. + + This will overwrite existing weights with similar names while keeping all extra weights intact. + This allows us to support the action expert which is used by the Pi0 model. + """ + + def load(self, params: at.Params) -> at.Params: + path = download.maybe_download( + "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz", gs={"token": "anon"} + ) + with path.open("rb") as f: + flat_params = dict(np.load(f, allow_pickle=False)) + loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]} + # Add all missing weights. + return _merge_params(loaded_params, params, missing_regex=".*") + + +def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params: + """Merges the loaded parameters with the reference parameters. + + Args: + loaded_params: The parameters to merge. + params: The reference parameters. + missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters. + + Returns: + A new dictionary with the merged parameters. 
+ """ + flat_ref = flax.traverse_util.flatten_dict(params, sep="/") + flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/") + + # First, take all weights that are a subset of the reference weights. + result = {} + for k, v in flat_loaded.items(): + if k in flat_ref: + result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v + + flat_loaded.clear() + + # Then, merge any missing weights as defined by the missing regex. + pattern = re.compile(missing_regex) + for k in {k for k in flat_ref if pattern.fullmatch(k)}: + if k not in result: + result[k] = flat_ref[k] + + return flax.traverse_util.unflatten_dict(result, sep="/") diff --git a/capvector-pi05/src/vggt/__init__.py b/capvector-pi05/src/vggt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-pi05/src/vggt/dependency/__init__.py b/capvector-pi05/src/vggt/dependency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2dad12e149190cb0f746c1fbdf306614e6302714 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/__init__.py @@ -0,0 +1,3 @@ +from .track_modules.track_refine import refine_track +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor diff --git a/capvector-pi05/src/vggt/dependency/distortion.py b/capvector-pi05/src/vggt/dependency/distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..3dad24807a04ddaf61917d4cb3aaf086a7c4095a --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/distortion.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from typing import Union + +ArrayLike = Union[np.ndarray, torch.Tensor] + + +def _is_numpy(x: ArrayLike) -> bool: + return isinstance(x, np.ndarray) + + +def _is_torch(x: ArrayLike) -> bool: + return isinstance(x, torch.Tensor) + + +def _ensure_torch(x: ArrayLike) -> torch.Tensor: + """Convert input to torch tensor if it's not already one.""" + if _is_numpy(x): + return torch.from_numpy(x) + elif _is_torch(x): + return x + else: + return torch.tensor(x) + + +def single_undistortion(params, tracks_normalized): + """ + Apply undistortion to the normalized tracks using the given distortion parameters once. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + u_undist, v_undist = apply_distortion(params, u, v) + return torch.stack([u_undist, v_undist], dim=-1) + + +def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6): + """ + Iteratively undistort the normalized tracks using the given distortion parameters. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. 
+ max_iterations (int): Maximum number of iterations for the undistortion process. + max_step_norm (float): Maximum step norm for convergence. + rel_step_size (float): Relative step size for numerical differentiation. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + B, N, _ = tracks_normalized.shape + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + original_u, original_v = u.clone(), v.clone() + + eps = torch.finfo(u.dtype).eps + for idx in range(max_iterations): + u_undist, v_undist = apply_distortion(params, u, v) + dx = original_u - u_undist + dy = original_v - v_undist + + step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) + step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) + + J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u) + J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v) + J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u) + J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v) + + J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2) + + delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) + + u += delta[..., 0] + v += delta[..., 1] + + if torch.max((delta**2).sum(dim=-1)) < max_step_norm: + break + + return torch.stack([u, v], dim=-1) + + +def apply_distortion(extra_params, u, v): + """ + Applies radial or OpenCV distortion to the given 2D points. + + Args: + extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks. + v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks. + + Returns: + points2D (torch.Tensor): Distorted 2D points of shape BxNx2. 
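+
+    For the single-parameter (simple radial) model this reduces to u' = u * (1 + k * r^2) and
+    v' = v * (1 + k * r^2) with r^2 = u^2 + v^2; the two- and four-parameter models add the
+    k2 * r^4 radial term and, for the OpenCV model, the tangential terms driven by p1 and p2.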
+ """ + extra_params = _ensure_torch(extra_params) + u = _ensure_torch(u) + v = _ensure_torch(v) + + num_params = extra_params.shape[1] + + if num_params == 1: + # Simple radial distortion + k = extra_params[:, 0] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k[:, None] * r2 + du = u * radial + dv = v * radial + + elif num_params == 2: + # RadialCameraModel distortion + k1, k2 = extra_params[:, 0], extra_params[:, 1] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + dv = v * radial + + elif num_params == 4: + # OpenCVCameraModel distortion + k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3]) + u2 = u * u + v2 = v * v + uv = u * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2) + dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2) + else: + raise ValueError("Unsupported number of distortion parameters") + + u = u.clone() + du + v = v.clone() + dv + + return u, v + + +if __name__ == "__main__": + import random + import pycolmap + + max_diff = 0 + for i in range(1000): + # Define distortion parameters (assuming 1 parameter for simplicity) + B = random.randint(1, 500) + track_num = random.randint(100, 1000) + params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters + tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # Batch size 1, 5 points + + # Undistort the tracks + undistorted_tracks = iterative_undistortion(params, tracks_normalized) + + for b in range(B): + pycolmap_intri = np.array([1, 0, 0, params[b].item()]) + pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0) + + undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy()) + diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median() + max_diff = max(max_diff, diff) + print(f"diff: {diff}, max_diff: {max_diff}") + + import pdb + + pdb.set_trace() diff --git a/capvector-pi05/src/vggt/dependency/np_to_pycolmap.py b/capvector-pi05/src/vggt/dependency/np_to_pycolmap.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c1fb856a69f329dbe4be3ea7627e4e5676f53 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/np_to_pycolmap.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import pycolmap +from .projection import project_3D_points_np + + +def batch_np_matrix_to_pycolmap( + points3d, + extrinsics, + intrinsics, + tracks, + image_size, + masks=None, + max_reproj_error=None, + max_points3D_val=3000, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", + extra_params=None, + min_inlier_per_frame=64, + points_rgb=None, +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Check https://github.com/colmap/pycolmap for more details about its format + + NOTE that colmap expects images/cameras/points3D to be 1-indexed + so there is a +1 offset between colmap index and batch index + + + NOTE: different from VGGSfM, this function: + 1. Use np instead of torch + 2. 
Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) + """ + # points3d: Px3 + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # tracks: NxPx2 + # masks: NxP + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N, P, _ = tracks.shape + assert len(extrinsics) == N + assert len(intrinsics) == N + assert len(points3d) == P + assert image_size.shape[0] == 2 + + reproj_mask = None + + if max_reproj_error is not None: + projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics) + projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) + projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 + reproj_mask = projected_diff < max_reproj_error + + if masks is not None and reproj_mask is not None: + masks = np.logical_and(masks, reproj_mask) + elif masks is not None: + masks = masks + else: + masks = reproj_mask + + assert masks is not None + + if masks.sum(1).min() < min_inlier_per_frame: + print(f"Not enough inliers per frame, skip BA.") + return None, None + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + inlier_num = masks.sum(0) + valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + # Only add 3D points that have sufficient 2D points + for vidx in valid_idx: + # Use RGB colors if provided, otherwise use zeros + rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) + + num_points3D = len(valid_idx) + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params) + + camera = pycolmap.Camera( + model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world + ) + + points2D_list = [] + + point2D_idx = 0 + + # NOTE point3D_id start by 1 + for point3D_id in range(1, num_points3D + 1): + original_track_idx = valid_idx[point3D_id - 1] + + if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): + if masks[fidx][original_track_idx]: + # It seems we don't need +0.5 for BA + point2D_xy = tracks[fidx][original_track_idx] + # Please note when adding the Point2D object + # It not only requires the 2D xy location, but also the id to 3D point + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: + print(f"frame {fidx + 1} is out of BA") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction, valid_mask + + +def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"): + """ + Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. 
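+
+    This is the inverse direction of `batch_np_matrix_to_pycolmap`: it reads the (1-indexed)
+    cameras, images, and 3D points back out of the reconstruction, e.g. after bundle adjustment.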
+ + Args: + reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. + device (str): Ignored in NumPy version (kept for API compatibility). + camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). + + Returns: + tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. + """ + + num_images = len(reconstruction.images) + max_points3D_id = max(reconstruction.point3D_ids()) + points3D = np.zeros((max_points3D_id, 3)) + + for point3D_id in reconstruction.points3D: + points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz + + extrinsics = [] + intrinsics = [] + + extra_params = [] if camera_type == "SIMPLE_RADIAL" else None + + for i in range(num_images): + # Extract and append extrinsics + pyimg = reconstruction.images[i + 1] + pycam = reconstruction.cameras[pyimg.camera_id] + matrix = pyimg.cam_from_world.matrix() + extrinsics.append(matrix) + + # Extract and append intrinsics + calibration_matrix = pycam.calibration_matrix() + intrinsics.append(calibration_matrix) + + if camera_type == "SIMPLE_RADIAL": + extra_params.append(pycam.params[-1]) + + # Convert lists to NumPy arrays instead of torch tensors + extrinsics = np.stack(extrinsics) + intrinsics = np.stack(intrinsics) + + if camera_type == "SIMPLE_RADIAL": + extra_params = np.stack(extra_params) + extra_params = extra_params[:, None] + + return points3D, extrinsics, intrinsics, extra_params + + +######################################################## + + +def batch_np_matrix_to_pycolmap_wo_track( + points3d, + points_xyf, + points_rgb, + extrinsics, + intrinsics, + image_size, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Different from batch_np_matrix_to_pycolmap, this function does not use tracks. + + It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. + + Do NOT use this for BA. 
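+
+    Each entry of `points_xyf` is attached only to the frame given by its third component, so
+    every 3D point ends up with a track of length one; this is enough to initialize Gaussians or
+    other NVS methods, but not usable for bundle adjustment.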
+ """ + # points3d: Px3 + # points_xyf: Px3, with x, y coordinates and frame indices + # points_rgb: Px3, rgb colors + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N = len(extrinsics) + P = len(points3d) + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + for vidx in range(P): + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) + + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) + + camera = pycolmap.Camera( + model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world + ) + + points2D_list = [] + + point2D_idx = 0 + + points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx + points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] + + for point3D_batch_idx in points_belong_to_fidx: + point3D_id = point3D_batch_idx + 1 + point2D_xyf = points_xyf[point3D_batch_idx] + point2D_xy = point2D_xyf[:2] + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: + print(f"frame {fidx + 1} does not have any points") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction + + +def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): + """ + Helper function to get camera parameters based on camera type. + + Args: + fidx: Frame index + intrinsics: Camera intrinsic parameters + camera_type: Type of camera model + extra_params: Additional parameters for certain camera types + + Returns: + pycolmap_intri: NumPy array of camera parameters + """ + if camera_type == "PINHOLE": + pycolmap_intri = np.array( + [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] + ) + elif camera_type == "SIMPLE_PINHOLE": + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]) + elif camera_type == "SIMPLE_RADIAL": + raise NotImplementedError("SIMPLE_RADIAL is not supported yet") + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]]) + else: + raise ValueError(f"Camera type {camera_type} is not supported yet") + + return pycolmap_intri diff --git a/capvector-pi05/src/vggt/dependency/projection.py b/capvector-pi05/src/vggt/dependency/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..38fd175fe6fce096ef2bfb6b0996085c5ae44fee --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/projection.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from .distortion import apply_distortion + + +def img_from_cam_np( + intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0 +) -> np.ndarray: + """ + Apply intrinsics (and optional radial distortion) to camera-space points. + + Args + ---- + intrinsics : (B,3,3) camera matrix K. + points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ. + extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. + default : value used for np.nan replacement. + + Returns + ------- + points2D : (B,N,2) pixel coordinates. + """ + # 1. perspective divide ─────────────────────────────────────── + z = points_cam[:, 2:3, :] # (B,1,N) + points_cam_norm = points_cam / z # (B,3,N) + uv = points_cam_norm[:, :2, :] # (B,2,N) + + # 2. optional distortion ────────────────────────────────────── + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = np.stack([uu, vv], axis=1) # (B,2,N) + + # 3. homogeneous coords then K multiplication ───────────────── + ones = np.ones_like(uv[:, :1, :]) # (B,1,N) + points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) + + # batched mat-mul: K · [u v 1]ᵀ + points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) + points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) + + return points2D.transpose(0, 2, 1) # (B,N,2) + + +def project_3D_points_np( + points3D: np.ndarray, + extrinsics: np.ndarray, + intrinsics: np.ndarray | None = None, + extra_params: np.ndarray | None = None, + *, + default: float = 0.0, + only_points_cam: bool = False, +): + """ + NumPy clone of ``project_3D_points``. + + Parameters + ---------- + points3D : (N,3) world-space points. + extrinsics : (B,3,4) [R|t] matrix for each of B cameras. + intrinsics : (B,3,3) K matrix (optional if you only need cam-space). + extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. + default : value used to replace NaNs. + only_points_cam : if True, skip the projection and return points_cam with points2D as None. + + Returns + ------- + (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, + and points_cam is (B,3,N) camera-space coordinates. + """ + # ----- 0. prep sizes ----------------------------------------------------- + N = points3D.shape[0] # #points + B = extrinsics.shape[0] # #cameras + + # ----- 1. world → homogeneous ------------------------------------------- + w_h = np.ones((N, 1), dtype=points3D.dtype) + points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) + + # broadcast to every camera (no actual copying with np.broadcast_to) ------ + points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) + + # ----- 2. apply extrinsics (camera frame) ------------------------------ + # X_cam = E · X_hom + # einsum: E_(b i j) · X_(b n j) → (b n i) + points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) + points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) + + if only_points_cam: + return None, points_cam + + # ----- 3. 
intrinsics + distortion --------------------------------------- + if intrinsics is None: + raise ValueError("`intrinsics` must be provided unless only_points_cam=True") + + points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default) + + return points2D, points_cam + + +def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + points3D (torch.Tensor): 3D points of shape Px3. + extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. + default (float): Default value to replace NaNs. + only_points_cam (bool): If True, skip the projection and return points2D as None. + + Returns: + tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, + and points_cam is of shape Bx3xN. + """ + with torch.cuda.amp.autocast(dtype=torch.double): + N = points3D.shape[0] # Number of points + B = extrinsics.shape[0] # Batch size, i.e., number of cameras + points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4 + # Reshape for batch processing + points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) + + if only_points_cam: + return None, points_cam + + # Step 2: Apply intrinsic parameters and (optional) distortion + points2D = img_from_cam(intrinsics, points_cam, extra_params, default) + + return points2D, points_cam + + +def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. + + Returns: + points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
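+
+    In effect, each camera-space column (x, y, z) is mapped to K @ (x/z, y/z, 1) after optionally
+    distorting the normalized (x/z, y/z) coordinates, and the first two rows are returned as (u, v).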
+ """ + + # Normalize by the third coordinate (homogeneous division) + points_cam = points_cam / points_cam[:, 2:3, :] + # Extract uv + uv = points_cam[:, :2, :] + + # Apply distortion if extra_params are provided + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = torch.stack([uu, vv], dim=1) + + # Prepare points_cam for batch matrix multiplication + points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN + + # Extract x and y coordinates + points2D = points2D_homo[:, :2, :] # Bx2xN + + # Replace NaNs with default value + points2D = torch.nan_to_num(points2D, nan=default) + + return points2D.transpose(1, 2) # BxNx2 + + +if __name__ == "__main__": + # Set up example input + B, N = 24, 10240 + + for _ in range(100): + points3D = np.random.rand(N, 3).astype(np.float64) + extrinsics = np.random.rand(B, 3, 4).astype(np.float64) + intrinsics = np.random.rand(B, 3, 3).astype(np.float64) + + # Convert to torch tensors + points3D_torch = torch.tensor(points3D) + extrinsics_torch = torch.tensor(extrinsics) + intrinsics_torch = torch.tensor(intrinsics) + + # Run NumPy implementation + points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics) + + # Run torch implementation + points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch) + + # Convert torch output to numpy + points2D_torch_np = points2D_torch.detach().numpy() + points_cam_torch_np = points_cam_torch.detach().numpy() + + # Compute difference + diff = np.abs(points2D_np - points2D_torch_np) + print("Difference between NumPy and PyTorch implementations:") + print(diff) + + # Check max error + max_diff = np.max(diff) + print(f"Maximum difference: {max_diff}") + + if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): + print("Implementations match closely.") + else: + print("Significant differences detected.") + + if points_cam_np is not None: + points_cam_diff = np.abs(points_cam_np - points_cam_torch_np) + print("Difference between NumPy and PyTorch camera-space coordinates:") + print(points_cam_diff) + + # Check max error + max_cam_diff = np.max(points_cam_diff) + print(f"Maximum camera-space coordinate difference: {max_cam_diff}") + + if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): + print("Camera-space coordinates match closely.") + else: + print("Significant differences detected in camera-space coordinates.") diff --git a/capvector-pi05/src/vggt/dependency/track_modules/__init__.py b/capvector-pi05/src/vggt/dependency/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py b/capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..27aa7092691f4526f4a93ca76170783b1e71c335 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=4, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + fine=False, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.fine = fine + + self.flows_emb_dim = latent_dim // 2 + self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 + + if self.fine: + # TODO this is the old dummy code, will remove this when we train next model + self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 + else: + self.transformer_dim += (4 - self.transformer_dim % 4) % 4 + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + if not self.fine: + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2 + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + # Construct the correlation block + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for itr in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + # Compute the correlation (check the implementation of CorrBlock) + + fcorr_fn.corr(track_feats) + fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim + + corrdim = fcorrs.shape[3] + + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + if transformer_input.shape[2] < self.transformer_dim: + # pad the features to match the dimension + pad_dim = self.transformer_dim - transformer_input.shape[2] + pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) + transformer_input = torch.cat([transformer_input, pad], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta = self.updateformer(x) + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 
2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + if not self.fine: + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + vis_e = torch.sigmoid(vis_e) + else: + vis_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat + else: + return coord_preds, vis_e diff --git a/capvector-pi05/src/vggt/dependency/track_modules/blocks.py b/capvector-pi05/src/vggt/dependency/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..513f96836644ff27e714cba517510d2dd7e702df --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/track_modules/blocks.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + +from .utils import bilinear_sampler + +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=4): + super(BasicEncoder, self).__init__() + + self.stride = stride + self.norm_fn = "instance" + self.in_planes = output_dim // 2 + + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(output_dim // 2, stride=1) + self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) + self.layer3 = self._make_layer(output_dim, stride=2) + self.layer4 = self._make_layer(output_dim, stride=2) + + self.conv2 = nn.Conv2d( + output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros" + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + + a = _bilinear_intepolate(a, self.stride, H, W) + b = _bilinear_intepolate(b, self.stride, H, W) + c = _bilinear_intepolate(c, self.stride, H, W) + d = _bilinear_intepolate(d, self.stride, H, W) + + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + 
return x + + +class ShallowEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"): + super(ShallowEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = output_dim + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(output_dim, stride=2) + + self.layer2 = self._make_layer(output_dim, stride=2) + self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + self.in_planes = dim + + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + return layer1 + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + tmp = self.layer1(x) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = self.layer2(tmp) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = None + x = self.conv2(x) + x + + x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True) + + return x + + +def _bilinear_intepolate(x, stride, H, W): + return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True) + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
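+
+    Tokens of shape (B, N, T, C) are refined by per-track temporal attention blocks; when
+    `add_space_attn` is set, these are interleaved with spatial attention across tracks, routed
+    through a small set of learned virtual tracks (`num_virtual_tracks`).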
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, input_tensor, mask=None): + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + flow = self.flow_head(tokens) + return flow + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, 
padding_mode="zeros"): + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.padding_mode = padding_mode + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.multiple_track_feats = multiple_track_feats + + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + *_, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode) + corrs = corrs.view(B, S, N, -1) + + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 + return out + + def corr(self, targets): + B, S, N, C = targets.shape + if self.multiple_track_feats: + targets_split = targets.split(C // self.num_levels, dim=-1) + B, S, N, C = targets_split[0].shape + + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + *_, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) + if self.multiple_track_feats: + fmap1 = targets_split[i] + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) diff --git a/capvector-pi05/src/vggt/dependency/track_modules/modules.py b/capvector-pi05/src/vggt/dependency/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e1a5cdb57239a9e40f8cf2e208622c06f6492004 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/track_modules/modules.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + 
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/capvector-pi05/src/vggt/dependency/track_modules/track_refine.py b/capvector-pi05/src/vggt/dependency/track_modules/track_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..461572c2096f2fb69de45cef9ba401465ebd084f --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/track_modules/track_refine.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from PIL import Image +import os +from typing import Union, Tuple + + +def refine_track( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960 +): + """ + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
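+
+    Concretely, with the default pradius=15 a (31 x 31) patch is cropped around every coarse
+    track location, re-tracked at patch resolution by `fine_tracker`, and the refined coordinates
+    are shifted back by each patch's top-left corner into image coordinates.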
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + if chunk < 0: + # Extract image patches based on top left corners + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + else: + patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) + + patch_feat_list = [] + for p in torch.split(patches, chunk): + patch_feat_list += [fine_fnet(p)] + patch_feat = torch.cat(patch_feat_list, 0) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +def refine_track_v0( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6 +): + """ + COPIED FROM VGGSfM + + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
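+            Only holds a value when compute_score is True; otherwise None is returned.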
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # Extract image patches based on top left corners + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +################################## NOTE: NOT USED ################################## + + +def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out): + """ + Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, + given the query point features and reference frame feature maps + """ + + from kornia.utils.grid import create_meshgrid + from kornia.geometry.subpix import dsnt + + # query_point_feat initial shape: B x N x C_out, + # query_point_feat indicates the feat at the coorponsing query points + # Therefore we don't have S dimension here + query_point_feat = query_point_feat.reshape(B, N, C_out) + # reshape and expand to B x (S-1) x N x C_out + query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) + # and reshape to (B*(S-1)*N) x C_out + query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) + + # Radius and size for computing the score + ssize = sradius * 2 + 1 + + # Reshape, you know it, so many reshaping operations + patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) + + # Again, we unfold the patches to smaller patches + # so that we can then focus on smaller patches + # patch_feat_unfold shape: + # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x 
ssize x ssize + # well a bit scary, but actually not + patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) + + # Do the same stuffs above, i.e., the same as extracting patches + fine_prediction_floor = fine_pred_track.floor().int() + fine_level_floor_topleft = fine_prediction_floor - sradius + + # Clamp to ensure the smaller patch is valid + fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) + fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) + + # Prepare the batch indices and xy locations + + batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN + batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N + y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices + x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices + + reference_frame_feat = patch_feat_unfold.reshape( + B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize + ) + + # Note again, according to pytorch convention + # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] + reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices] + reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) + # pick the frames other than the first one, so we have S-1 frames here + reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize) + + # Compute similarity + sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) + softmax_temp = 1.0 / C_out**0.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + # 2D heatmaps + heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize + + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape( + 1, -1, 2 + ) + + var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2 + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability + + score = std.reshape(B, S - 1, N) + # set score as 1 for the query frame + score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) + + return score + + +def extract_glimpse( + tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None +): + B, C, W, H = tensor.shape + + h, w = size + xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 + ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 + + vy, vx = torch.meshgrid(ys, xs) + grid = torch.stack([vx, vy], dim=-1) # h, w, 2 + grid = grid[None] + + B, N, _ = offsets.shape + + offsets = offsets.reshape((B * N), 1, 1, 2) + offsets_grid = offsets + grid + + # normalised grid to [-1, 1] + offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2]) + + # BxCxHxW -> Bx1xCxHxW + tensor = tensor[:, None] + + # Bx1xCxHxW -> BxNxCxHxW + tensor = tensor.expand(-1, N, -1, -1, -1) + + # BxNxCxHxW -> (B*N)xCxHxW + tensor = tensor.reshape((B * N), C, W, H) + + sampled = torch.nn.functional.grid_sample( + tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode + ) + + # NOTE: I am not sure it should be h, w or w, h here + # but okay for sqaures + sampled = sampled.reshape(B, N, C, h, w) + + return 
sampled diff --git a/capvector-pi05/src/vggt/dependency/track_modules/utils.py b/capvector-pi05/src/vggt/dependency/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b002c055ec44d2bf65f99041c47a53d6b0c9b1 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/track_modules/utils.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/PoseDiffusion +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union +from einops import rearrange, repeat + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. 
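+
+    Example (illustrative shape check):
+        >>> emb = get_1d_sincos_pos_embed_from_grid(64, torch.arange(16, dtype=torch.float))
+        >>> emb.shape  # torch.Size([1, 16, 64])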
+ """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
+ """ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device) + else: + coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) + + coords -= 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/capvector-pi05/src/vggt/dependency/track_predict.py b/capvector-pi05/src/vggt/dependency/track_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..22465c0b53d3e310e0025c06aed8dabccf7339d3 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/track_predict.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from .vggsfm_utils import * + + +def predict_tracks( + images, + conf=None, + points_3d=None, + masks=None, + max_query_pts=2048, + query_frame_num=5, + keypoint_extractor="aliked+sp", + max_points_num=163840, + fine_tracking=True, + complete_non_vis=True, +): + """ + Predict tracks for the given images and masks. + + TODO: support non-square images + TODO: support masks + + + This function predicts the tracks for the given images and masks using the specified query method + and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. + + Args: + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. + points_3d: Tensor containing 3D points. Default is None. + masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None. + max_query_pts: Maximum number of query points. Default is 2048. + query_frame_num: Number of query frames to use. Default is 5. + keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". + max_points_num: Maximum number of points to process at once. Default is 163840. + fine_tracking: Whether to use fine tracking. Default is True. + complete_non_vis: Whether to augment non-visible frames. Default is True. + + Returns: + pred_tracks: Numpy array containing the predicted tracks. + pred_vis_scores: Numpy array containing the visibility scores for the tracks. 
+ pred_confs: Numpy array containing the confidence scores for the tracks. + pred_points_3d: Numpy array containing the 3D points for the tracks. + pred_colors: Numpy array containing the point colors for the tracks. (0, 255) + """ + + device = images.device + dtype = images.dtype + tracker = build_vggsfm_tracker().to(device, dtype) + + # Find query frames + query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device) + + # Add the first image to the front if not already present + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + # TODO: add the functionality to handle the masks + keypoint_extractors = initialize_feature_extractors( + max_query_pts, extractor_method=keypoint_extractor, device=device + ) + + pred_tracks = [] + pred_vis_scores = [] + pred_confs = [] + pred_points_3d = [] + pred_colors = [] + + fmaps_for_tracker = tracker.process_images_to_fmaps(images) + + if fine_tracking: + print("For faster inference, consider disabling fine_tracking") + + for query_index in query_frame_indexes: + print(f"Predicting tracks for query frame {query_index}") + pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + pred_confs.append(pred_conf) + pred_points_3d.append(pred_point_3d) + pred_colors.append(pred_color) + + if complete_non_vis: + pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames( + pred_tracks, + pred_vis_scores, + pred_confs, + pred_points_3d, + pred_colors, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + min_vis=500, + non_vis_thresh=0.1, + device=device, + ) + + pred_tracks = np.concatenate(pred_tracks, axis=1) + pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) + pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None + pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None + pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + # from vggt.utils.visual_track import visualize_tracks_on_images + # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors + + +def _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, +): + """ + Process a single query frame for track prediction. 
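+    Keypoints are extracted from the query frame, optionally filtered by the point-map
+    confidence, and then tracked across all frames (query points are split into chunks
+    when frames x points exceeds max_points_num).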
+ + Args: + query_index: Index of the query frame + images: Tensor of shape [S, 3, H, W] containing the input images + conf: Confidence tensor + points_3d: 3D points tensor + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + device: Device to use for computation + + Returns: + pred_track: Predicted tracks + pred_vis: Visibility scores for the tracks + pred_conf: Confidence scores for the tracks + pred_point_3d: 3D points for the tracks + pred_color: Point colors for the tracks (0, 255) + """ + frame_num, _, height, width = images.shape + + query_image = images[query_index] + query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False) + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + + # Extract the color at the keypoint locations + query_points_long = query_points.squeeze(0).round().long() + pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]] + pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) + + # Query the confidence and points_3d at the keypoint locations + if (conf is not None) and (points_3d is not None): + assert height == width + assert conf.shape[-2] == conf.shape[-1] + assert conf.shape[:3] == points_3d.shape[:3] + scale = conf.shape[-1] / width + + query_points_scaled = (query_points.squeeze(0) * scale).round().long() + query_points_scaled = query_points_scaled.cpu().numpy() + + pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] + pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] + + # heuristic to remove low confidence points + # should I export this as an input parameter? 
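+        # The 1.2 threshold (and the 512 minimum below) are empirical; if too few points
+        # pass the filter, all query points are kept and the conf / 3D outputs are dropped.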
+ valid_mask = pred_conf > 1.2 + if valid_mask.sum() > 512: + query_points = query_points[:, valid_mask] # Make sure shape is compatible + pred_conf = pred_conf[valid_mask] + pred_point_3d = pred_point_3d[valid_mask] + pred_color = pred_color[valid_mask] + else: + pred_conf = None + pred_point_3d = None + + reorder_index = calculate_index_mappings(query_index, frame_num, device=device) + + images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0) + images_feed = images_feed[None] # add batch dimension + fmaps_feed = fmaps_feed[None] # add batch dimension + + all_points_num = images_feed.shape[1] * query_points.shape[1] + + # Don't need to be scared, this is just chunking to make GPU happy + if all_points_num > max_points_num: + num_splits = (all_points_num + max_points_num - 1) // max_points_num + query_points = torch.chunk(query_points, num_splits, dim=1) + else: + query_points = [query_points] + + pred_track, pred_vis, _ = predict_tracks_in_chunks( + tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking + ) + + pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1) + + pred_track = pred_track.squeeze(0).float().cpu().numpy() + pred_vis = pred_vis.squeeze(0).float().cpu().numpy() + + return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color + + +def _augment_non_visible_frames( + pred_tracks: list, # ← running list of np.ndarrays + pred_vis_scores: list, # ← running list of np.ndarrays + pred_confs: list, # ← running list of np.ndarrays for confidence scores + pred_points_3d: list, # ← running list of np.ndarrays for 3D points + pred_colors: list, # ← running list of np.ndarrays for colors + images: torch.Tensor, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num: int, + fine_tracking: bool, + *, + min_vis: int = 500, + non_vis_thresh: float = 0.1, + device: torch.device = None, +): + """ + Augment tracking for frames with insufficient visibility. + + Args: + pred_tracks: List of numpy arrays containing predicted tracks. + pred_vis_scores: List of numpy arrays containing visibility scores. + pred_confs: List of numpy arrays containing confidence scores. + pred_points_3d: List of numpy arrays containing 3D points. + pred_colors: List of numpy arrays containing point colors. + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing confidence scores + points_3d: Tensor containing 3D points + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + min_vis: Minimum visibility threshold + non_vis_thresh: Non-visibility threshold + device: Device to use for computation + + Returns: + Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
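+        Frames whose number of sufficiently visible tracks stays below min_vis are re-queried
+        one at a time; if the same frame comes up twice in a row, a final attempt is made on
+        all remaining frames at once with the "sp+sift+aliked" extractor combination.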
+ """ + last_query = -1 + final_trial = False + cur_extractors = keypoint_extractors # may be replaced on the final trial + + while True: + # Visibility per frame + vis_array = np.concatenate(pred_vis_scores, axis=1) + + # Count frames with sufficient visibility using numpy + sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) + non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() + + if len(non_vis_frames) == 0: + break + + print("Processing non visible frames:", non_vis_frames) + + # Decide the frames & extractor for this round + if non_vis_frames[0] == last_query: + # Same frame failed twice - final "all-in" attempt + final_trial = True + cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device) + query_frame_list = non_vis_frames # blast them all at once + else: + query_frame_list = [non_vis_frames[0]] # Process one at a time + + last_query = non_vis_frames[0] + + # Run the tracker for every selected frame + for query_index in query_frame_list: + new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + cur_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + pred_tracks.append(new_track) + pred_vis_scores.append(new_vis) + pred_confs.append(new_conf) + pred_points_3d.append(new_point_3d) + pred_colors.append(new_color) + + if final_trial: + break # Stop after final attempt + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors diff --git a/capvector-pi05/src/vggt/dependency/vggsfm_tracker.py b/capvector-pi05/src/vggt/dependency/vggsfm_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..e3940907f2bde73886af29af5e4ef8250b5b0d1b --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/vggsfm_tracker.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from hydra.utils import instantiate +from omegaconf import OmegaConf + +from .track_modules.track_refine import refine_track +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackerPredictor(nn.Module): + def __init__(self, **extra_args): + super(TrackerPredictor, self).__init__() + """ + Initializes the tracker predictor. 
+ + Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, + check track_modules/base_track_predictor.py + + Both coarse_fnet and fine_fnet are constructed as a 2D CNN network + check track_modules/blocks.py for BasicEncoder and ShallowEncoder + """ + # Define coarse predictor configuration + coarse_stride = 4 + self.coarse_down_ratio = 2 + + # Create networks directly instead of using instantiate + self.coarse_fnet = BasicEncoder(stride=coarse_stride) + self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) + + # Create fine predictor with stride = 1 + self.fine_fnet = ShallowEncoder(stride=1) + self.fine_predictor = BaseTrackerPredictor( + stride=1, + depth=4, + corr_levels=3, + corr_radius=3, + latent_dim=32, + hidden_size=256, + fine=True, + use_spaceatt=False, + ) + + def forward( + self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960 + ): + """ + Args: + images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. + query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. + fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. + coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. + inference (bool, optional): Whether to perform inference. Defaults to True. + fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. + + Returns: + tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. + """ + + if fmaps is None: + batch_num, frame_num, image_dim, height, width = images.shape + reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) + fmaps = self.process_images_to_fmaps(reshaped_image) + fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]) + + if inference: + torch.cuda.empty_cache() + + # Coarse prediction + coarse_pred_track_lists, pred_vis = self.coarse_predictor( + query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio + ) + coarse_pred_track = coarse_pred_track_lists[-1] + + if inference: + torch.cuda.empty_cache() + + if fine_tracking: + # Refine the coarse prediction + fine_pred_track, pred_score = refine_track( + images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk + ) + + if inference: + torch.cuda.empty_cache() + else: + fine_pred_track = coarse_pred_track + pred_score = torch.ones_like(pred_vis) + + return fine_pred_track, coarse_pred_track, pred_vis, pred_score + + def process_images_to_fmaps(self, images): + """ + This function processes images for inference. + + Args: + images (torch.Tensor): The images to be processed with shape S x 3 x H x W. + + Returns: + torch.Tensor: The processed feature maps. + """ + if self.coarse_down_ratio > 1: + # whether or not scale down the input images to save memory + fmaps = self.coarse_fnet( + F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True) + ) + else: + fmaps = self.coarse_fnet(images) + + return fmaps diff --git a/capvector-pi05/src/vggt/dependency/vggsfm_utils.py b/capvector-pi05/src/vggt/dependency/vggsfm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b75497199d18d62f3fb5db1f203fe5edccedf2 --- /dev/null +++ b/capvector-pi05/src/vggt/dependency/vggsfm_utils.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pycolmap +import torch +import torch.nn.functional as F +from lightglue import ALIKED, SIFT, SuperPoint + +from .vggsfm_tracker import TrackerPredictor + +# Suppress verbose logging from dependencies +logging.getLogger("dinov2").setLevel(logging.WARNING) +warnings.filterwarnings("ignore", message="xFormers is available") +warnings.filterwarnings("ignore", message="dinov2") + +# Constants +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +def build_vggsfm_tracker(model_path=None): + """ + Build and initialize the VGGSfM tracker. + + Args: + model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. + + Returns: + Initialized tracker model in eval mode. + """ + tracker = TrackerPredictor() + + if model_path is None: + default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" + tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) + else: + tracker.load_state_dict(torch.load(model_path)) + + tracker.eval() + return tracker + + +def generate_rank_by_dino( + images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False +): + """ + Generate a ranking of frames using DINO ViT features. + + Args: + images: Tensor of shape (S, 3, H, W) with values in range [0, 1] + query_frame_num: Number of frames to select + image_size: Size to resize images to before processing + model_name: Name of the DINO model to use + device: Device to run the model on + spatial_similarity: Whether to use spatial token similarity or CLS token similarity + + Returns: + List of frame indices ranked by their representativeness + """ + # Resize images to the target size + images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False) + + # Load DINO model + dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dino_v2_model.eval() + dino_v2_model = dino_v2_model.to(device) + + # Normalize images using ResNet normalization + resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) + resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) + images_resnet_norm = (images - resnet_mean) / resnet_std + + with torch.no_grad(): + frame_feat = dino_v2_model(images_resnet_norm, is_training=True) + + # Process features based on similarity type + if spatial_similarity: + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similarity matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + similarity_matrix = similarity_matrix.mean(dim=0) + else: + frame_feat = frame_feat["x_norm_clstoken"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + + distance_matrix = 100 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(-100) + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling starting from the most common frame + fps_idx = 
farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) + + # Clean up all tensors and models to free memory + del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix + del dino_v2_model + torch.cuda.empty_cache() + + return fps_idx + + +def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): + """ + Farthest point sampling algorithm to select diverse frames. + + Args: + distance_matrix: Matrix of distances between frames + num_samples: Number of frames to select + most_common_frame_index: Index of the first frame to select + + Returns: + List of selected frame indices + """ + distance_matrix = distance_matrix.clamp(min=0) + N = distance_matrix.size(0) + + # Initialize with the most common frame + selected_indices = [most_common_frame_index] + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + # Find the farthest point from the current set of selected points + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + + check_distances = distance_matrix[farthest_point] + # Mark already selected points to avoid selecting them again + check_distances[selected_indices] = 0 + + # Break if all points have been selected + if len(selected_indices) == N: + break + + return selected_indices + + +def calculate_index_mappings(query_index, S, device=None): + """ + Construct an order that switches [query_index] and [0] + so that the content of query_index would be placed at [0]. + + Args: + query_index: Index to swap with 0 + S: Total number of elements + device: Device to place the tensor on + + Returns: + Tensor of indices with the swapped order + """ + new_order = torch.arange(S) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors, order, dim=1): + """ + Reorder tensors along a specific dimension according to the given order. + + Args: + tensors: List of tensors to reorder + order: Tensor of indices specifying the new order + dim: Dimension along which to reorder + + Returns: + List of reordered tensors + """ + return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors] + + +def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"): + """ + Initialize feature extractors that can be reused based on a method string. 
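+    For example, "aliked+sp" builds both an ALIKED and a SuperPoint extractor; unrecognized
+    names are ignored with a warning, and ALIKED is used as a fallback if none are valid.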
+ + Args: + max_query_num: Maximum number of keypoints to extract + det_thres: Detection threshold for keypoint extraction + extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") + device: Device to run extraction on + + Returns: + Dictionary of initialized extractors + """ + extractors = {} + methods = extractor_method.lower().split("+") + + for method in methods: + method = method.strip() + if method == "aliked": + aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["aliked"] = aliked_extractor.to(device).eval() + elif method == "sp": + sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["sp"] = sp_extractor.to(device).eval() + elif method == "sift": + sift_extractor = SIFT(max_num_keypoints=max_query_num) + extractors["sift"] = sift_extractor.to(device).eval() + else: + print(f"Warning: Unknown feature extractor '{method}', ignoring.") + + if not extractors: + print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") + aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["aliked"] = aliked_extractor.to(device).eval() + + return extractors + + +def extract_keypoints(query_image, extractors, round_keypoints=True): + """ + Extract keypoints using pre-initialized feature extractors. + + Args: + query_image: Input image tensor (3xHxW, range [0, 1]) + extractors: Dictionary of initialized extractors + + Returns: + Tensor of keypoint coordinates (1xNx2) + """ + query_points = None + + with torch.no_grad(): + for extractor_name, extractor in extractors.items(): + query_points_data = extractor.extract(query_image, invalid_mask=None) + extractor_points = query_points_data["keypoints"] + if round_keypoints: + extractor_points = extractor_points.round() + + if query_points is not None: + query_points = torch.cat([query_points, extractor_points], dim=1) + else: + query_points = extractor_points + + return query_points + + +def predict_tracks_in_chunks( + track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960 +): + """ + Process a list of query points to avoid memory issues. + + Args: + track_predictor (object): The track predictor object used for predicting tracks. + images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. + query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. + fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. + fine_tracking (bool): Whether to perform fine tracking. + num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. + + Returns: + tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. 
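+            The score entry may be None when the underlying tracker does not compute scores.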
+ """ + # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility + if not isinstance(query_points_list, (list, tuple)): + query_points = query_points_list + if num_splits is None: + num_splits = 1 + query_points_list = torch.chunk(query_points, num_splits, dim=1) + + # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) + if isinstance(query_points_list, tuple): + query_points_list = list(query_points_list) + + fine_pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for split_points in query_points_list: + # Feed into track predictor for each split + fine_pred_track, _, pred_vis, pred_score = track_predictor( + images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk + ) + fine_pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + # Concatenate the results from all splits + fine_pred_track = torch.cat(fine_pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + + if pred_score is not None: + pred_score = torch.cat(pred_score_list, dim=2) + else: + pred_score = None + + return fine_pred_track, pred_vis, pred_score diff --git a/capvector-pi05/src/vggt/heads/camera_head.py b/capvector-pi05/src/vggt/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b88b7d6e909adf45853af03d546343b5a8bbe472 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/camera_head.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vggt.layers import Mlp +from vggt.layers.block import Block +from vggt.heads.head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. 
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/capvector-pi05/src/vggt/heads/dpt_head.py b/capvector-pi05/src/vggt/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6f88a404f50735ece07b44714c986de0f4efcfe3 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/dpt_head.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch(out_channels, features, expand=False) + + # Attach additional modules to scratch. 
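+        # refinenet4 fuses the deepest feature level first, so it is built without a residual
+        # input from a previous fusion stage (has_residual=False).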
+ self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. + + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. 
+ images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.reshape(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. + out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. 
+ """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/capvector-pi05/src/vggt/heads/head_act.py b/capvector-pi05/src/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..a37669d50cf9b52dba297c2c7a5bea00987bb67d --- /dev/null +++ b/capvector-pi05/src/vggt/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. 
+ + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/capvector-pi05/src/vggt/heads/track_head.py b/capvector-pi05/src/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e6356cbd8273557deac9225cd09125ebca34fc65 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/track_head.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
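# --- Illustrative sketch (editor's note, not part of the patch above) ---
# A minimal usage check for the activation helpers defined in head_act.py,
# assuming that file is importable as `vggt.heads.head_act`; tensor sizes are
# made up for the example.
import torch
from vggt.heads.head_act import activate_pose, inverse_log_transform

# Fake pose encodings: 3 translation dims, 4 quaternion dims, 2 focal/FoV dims.
pred_pose_enc = torch.randn(2, 8, 9)
activated = activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="relu")
assert activated.shape == (2, 8, 9)  # activations are applied per component, shape is preserved

# inverse_log_transform undoes a sign-preserving log1p: sign(y) * expm1(|y|).
y = torch.tensor([-2.0, 0.0, 3.0])
roundtrip = inverse_log_transform(torch.sign(y) * torch.log1p(torch.abs(y)))
assert torch.allclose(roundtrip, y, atol=1e-6)
# --- end of illustrative sketch ---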
+ +import torch.nn as nn +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. + iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. + """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor, optional): Initial query points to track. + If None, points are initialized by the tracker. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (torch.Tensor): Predicted coordinates for tracked points. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). 
+ """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters) + + return coord_preds, vis_scores, conf_scores diff --git a/capvector-pi05/src/vggt/heads/track_modules/__init__.py b/capvector-pi05/src/vggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4196294309799347172dba54a17360698071ca8 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/track_modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py b/capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..540c1d110d4b35b36fdbd2a8a81121d9f9cf2f9b --- /dev/null +++ b/capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed +from .modules import Mlp + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + and https://github.com/facebookresearch/vggsfm + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 
1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. + note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 
2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/capvector-pi05/src/vggt/heads/track_modules/blocks.py b/capvector-pi05/src/vggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..394c31d120a716bee1e82911c841c0f63d9965d3 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/track_modules/blocks.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = 
self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + """ + Build a pyramid of feature maps from the input. + + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. Δx, Δy) + self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) — features for the current targets. + coords: Tensor (B, S, N, 2) — coordinates at full resolution. + + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. + delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) + + # Sample from the correlation volume using bilinear interpolation. 
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. + out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/capvector-pi05/src/vggt/heads/track_modules/modules.py b/capvector-pi05/src/vggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..84a9f64bda7d749f01b9b9243b13659461008355 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/track_modules/modules.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return 
self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/capvector-pi05/src/vggt/heads/track_modules/utils.py b/capvector-pi05/src/vggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc9486f5d070c882273de1165e6f1322b6c5ce4 --- /dev/null +++ b/capvector-pi05/src/vggt/heads/track_modules/utils.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
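# --- Illustrative sketch (editor's note, not part of the patch above) ---
# Shape sanity check for the AttnBlock / CrossAttnBlock defined in modules.py,
# assuming that module is importable as `vggt.heads.track_modules.modules`;
# the dimensions below are arbitrary.
import torch
from vggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

hidden = 384
x = torch.randn(4, 16, hidden)    # (batch, tokens, channels)
ctx = torch.randn(4, 32, hidden)  # context sequence for cross-attention

self_block = AttnBlock(hidden_size=hidden, num_heads=8, mlp_ratio=4.0)
cross_block = CrossAttnBlock(hidden_size=hidden, context_dim=hidden, num_heads=8)

y = self_block(x)        # pre-norm self-attention + MLP, residuals added
z = cross_block(x, ctx)  # queries in x attend to the normalized context
assert y.shape == x.shape and z.shape == x.shape
# --- end of illustrative sketch ---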
+ +# Modified from https://github.com/facebookresearch/vggsfm +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. 
+ """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
+ """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype + ) + else: + scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/capvector-pi05/src/vggt/heads/utils.py b/capvector-pi05/src/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1804227cff9d9fde67712bbb7e5d64b4be88d6cf --- /dev/null +++ b/capvector-pi05/src/vggt/heads/utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import numpy as np +from typing import List, Dict, Tuple, Union +from einops import rearrange + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. 
+ + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid + + +def _interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) + + +def _apply_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. 
+ """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + +def interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe): + (patch_h, patch_w) = patch_hw + (img_h, img_w) = img_hw + bs, N, S, D = hidden.shape + re_sample_ratio = 1 / np.sqrt(N * S / reference.shape[1]) + + _hidden = hidden.permute(0, 1, 3, 2) + _hidden = _hidden.reshape(bs*N, D, patch_h, patch_w) + if use_vggt_pe: + _hidden = _apply_pos_embed(_hidden, img_w, img_h) + hidden_pooling = _interpolate( + _hidden, scale_factor=re_sample_ratio, mode=pooling_func, align_corners=True + ) + hidden_pooling = hidden_pooling.reshape(bs, N, D, -1).permute(0, 1, 3, 2).reshape(bs, -1, D) + return hidden_pooling + + +def custom_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe): + if pooling_func in ['bilinear']: + return interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe) + else: + raise NotImplementedError(f"Pooling function {pooling_func} is not implemented.") diff --git a/capvector-pi05/src/vggt/layers/__init__.py b/capvector-pi05/src/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e59a83eb90512d763b03e4d38536b6ae07e87541 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/capvector-pi05/src/vggt/layers/attention.py b/capvector-pi05/src/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..27329716a95a1c3e70a12e74b3be5fe79f2663f9 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: + assert pos is None + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/capvector-pi05/src/vggt/layers/block.py b/capvector-pi05/src/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..15fa99ce76d14b6b9c2e98c1031fa35a3046a429 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/block.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
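+
+# Transformer block used by the VGGT backbone: pre-norm attention and MLP residual branches
+# with optional LayerScale and stochastic depth, plus `NestedTensorBlock` for running a list
+# of tensors through xFormers block-diagonal attention.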
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + x = drop_add_residual_stochastic_depth( + x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add 
the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + 
scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None), + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/capvector-pi05/src/vggt/layers/drop_path.py b/capvector-pi05/src/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb1487b0eed4cb14dc0d5d1ee57a2acc78de34a --- /dev/null +++ b/capvector-pi05/src/vggt/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/capvector-pi05/src/vggt/layers/layer_scale.py b/capvector-pi05/src/vggt/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..9047736a9fcd57a091aac8d42a8c07cc348cd1b3 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/layer_scale.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
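+
+# LayerScale: scales features by a learnable per-channel `gamma` (default init 1e-5),
+# optionally in-place.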
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/capvector-pi05/src/vggt/layers/mlp.py b/capvector-pi05/src/vggt/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..0965768a9aef04ac6b81322f4dd60cf035159e91 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/capvector-pi05/src/vggt/layers/patch_embed.py b/capvector-pi05/src/vggt/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..7244ad8e3b956417f52b4bcea1aefb3796fc7e59 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/patch_embed.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
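+        flatten_embedding: If True, flatten the output to (B, N, D); otherwise keep (B, H, W, D).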
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/capvector-pi05/src/vggt/layers/rope.py b/capvector-pi05/src/vggt/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..107ff7a2267e936bc01c4dbd576e1da4f038f904 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/rope.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. 
+ + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
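+
+        The rotation is applied as: tokens * cos + _rotate_features(tokens) * sin.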
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/capvector-pi05/src/vggt/layers/swiglu_ffn.py b/capvector-pi05/src/vggt/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6b6c74f97e61041ecef912ea21c2d259335aa7 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/swiglu_ffn.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
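+
+# SwiGLU feed-forward layers. `SwiGLUFFNFused` rescales hidden_features by 2/3 and rounds it
+# up to a multiple of 8; without xFormers, `SwiGLU` falls back to the plain `SwiGLUFFN`.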
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) diff --git a/capvector-pi05/src/vggt/layers/vision_transformer.py b/capvector-pi05/src/vggt/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ced58dd042a84b44ca97ce3f25d3983f322a8e27 --- /dev/null +++ b/capvector-pi05/src/vggt/layers/vision_transformer.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . 
import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.use_reentrant = False # hardcoded to False + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1) + + 
return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/capvector-pi05/src/vggt/models/aggregator.py b/capvector-pi05/src/vggt/models/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccc16110c82008a7652d05834a478226780a9b5 --- /dev/null +++ b/capvector-pi05/src/vggt/models/aggregator.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from typing import Optional, Tuple, Union, List, Dict, Any + +from vggt.layers import PatchEmbed +from vggt.layers.block import Block +from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + Remember to set model.train() to enable gradient checkpointing to reduce memory usage. + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. + qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. 
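+
+    Example (sketch):
+        aggregator = Aggregator()  # defaults use the dinov2_vitl14_reg patch embedding
+        tokens_list, patch_start_idx = aggregator(images)  # images: [B, S, 3, H, W] in [0, 1]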
+ """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) + + # Initialize rotary position embedding if frequency > 0 + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): + self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False) + + self.use_reentrant = False # hardcoded to False + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
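+        Supported ViT variants: dinov2_vits14_reg, dinov2_vitb14_reg, dinov2_vitl14_reg and
+        dinov2_vitg2_reg; any name containing "conv" selects the plain conv PatchEmbed instead.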
+ """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten(self.camera_token, B, S) + register_token = slice_expand_and_flatten(self.register_token, B, S) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). 
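+        Each frame's P tokens attend only to tokens of the same frame (the frame index is
+        folded into the batch dimension B*S), in contrast to _process_global_attention,
+        which attends across all S*P tokens of a sequence.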
+ """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if self.training: + tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant) + else: + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). + """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if self.training: + tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant) + else: + tokens = self.global_blocks[global_idx](tokens, pos=pos) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) 
+ combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/capvector-pi05/src/vggt/pyproject.toml b/capvector-pi05/src/vggt/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..2e3ea16f6bc84967a90d7d319c204ca92936b8b5 --- /dev/null +++ b/capvector-pi05/src/vggt/pyproject.toml @@ -0,0 +1,52 @@ +[project] +authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}] +dependencies = [ + "numpy<2", + "Pillow", + "huggingface_hub", + "einops", + "safetensors", + "opencv-python", +] +name = "vggt" +requires-python = ">= 3.10" +version = "0.0.1" + +[project.optional-dependencies] +demo = [ + "gradio==5.17.1", + "viser==0.2.23", + "tqdm", + "hydra-core", + "omegaconf", + "opencv-python", + "scipy", + "onnxruntime", + "requests", + "trimesh", + "matplotlib", +] + +# Using setuptools as the build backend +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +# setuptools configuration +[tool.setuptools.packages.find] +where = ["."] +include = ["vggt*"] + +# Pixi configuration +[tool.pixi.workspace] +channels = ["conda-forge"] +platforms = ["linux-64"] + +[tool.pixi.pypi-dependencies] +vggt = { path = ".", editable = true } + +[tool.pixi.environments] +default = { solve-group = "default" } +demo = { features = ["demo"], solve-group = "default" } + +[tool.pixi.tasks]
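+# Example (hypothetical) task definition; replace the command with the actual demo entry point:
+# demo = "python demo_gradio.py"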