Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the remaining files.
- capvector-pi05/src/openpi/policies/policy_test.py +34 -0
- capvector-pi05/src/openpi/serving/websocket_policy_server.py +90 -0
- capvector-pi05/src/openpi/shared/__init__.py +0 -0
- capvector-pi05/src/openpi/shared/download.py +194 -0
- capvector-pi05/src/openpi/shared/download_test.py +54 -0
- capvector-pi05/src/openpi/shared/image_tools.py +186 -0
- capvector-pi05/src/openpi/shared/image_tools_test.py +37 -0
- capvector-pi05/src/openpi/shared/nnx_utils.py +69 -0
- capvector-pi05/src/openpi/shared/normalize.py +146 -0
- capvector-pi05/src/openpi/shared/normalize_test.py +43 -0
- capvector-pi05/src/openpi/training/checkpoints.py +159 -0
- capvector-pi05/src/openpi/training/config.py +1033 -0
- capvector-pi05/src/openpi/training/data_loader.py +540 -0
- capvector-pi05/src/openpi/training/data_loader_test.py +84 -0
- capvector-pi05/src/openpi/training/droid_rlds_dataset.py +221 -0
- capvector-pi05/src/openpi/training/misc/roboarena_config.py +116 -0
- capvector-pi05/src/openpi/training/optimizer.py +109 -0
- capvector-pi05/src/openpi/training/sharding.py +102 -0
- capvector-pi05/src/openpi/training/utils.py +38 -0
- capvector-pi05/src/openpi/training/weight_loaders.py +104 -0
- capvector-pi05/src/vggt/__init__.py +0 -0
- capvector-pi05/src/vggt/dependency/__init__.py +3 -0
- capvector-pi05/src/vggt/dependency/distortion.py +182 -0
- capvector-pi05/src/vggt/dependency/np_to_pycolmap.py +320 -0
- capvector-pi05/src/vggt/dependency/projection.py +228 -0
- capvector-pi05/src/vggt/dependency/track_modules/__init__.py +0 -0
- capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py +190 -0
- capvector-pi05/src/vggt/dependency/track_modules/blocks.py +329 -0
- capvector-pi05/src/vggt/dependency/track_modules/modules.py +202 -0
- capvector-pi05/src/vggt/dependency/track_modules/track_refine.py +419 -0
- capvector-pi05/src/vggt/dependency/track_modules/utils.py +216 -0
- capvector-pi05/src/vggt/dependency/track_predict.py +326 -0
- capvector-pi05/src/vggt/dependency/vggsfm_tracker.py +124 -0
- capvector-pi05/src/vggt/dependency/vggsfm_utils.py +305 -0
- capvector-pi05/src/vggt/heads/camera_head.py +149 -0
- capvector-pi05/src/vggt/heads/dpt_head.py +484 -0
- capvector-pi05/src/vggt/heads/head_act.py +125 -0
- capvector-pi05/src/vggt/heads/track_head.py +104 -0
- capvector-pi05/src/vggt/heads/track_modules/__init__.py +5 -0
- capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py +209 -0
- capvector-pi05/src/vggt/heads/track_modules/blocks.py +236 -0
- capvector-pi05/src/vggt/heads/track_modules/modules.py +204 -0
- capvector-pi05/src/vggt/heads/track_modules/utils.py +223 -0
- capvector-pi05/src/vggt/heads/utils.py +176 -0
- capvector-pi05/src/vggt/layers/__init__.py +11 -0
- capvector-pi05/src/vggt/layers/attention.py +93 -0
- capvector-pi05/src/vggt/layers/block.py +247 -0
- capvector-pi05/src/vggt/layers/drop_path.py +34 -0
- capvector-pi05/src/vggt/layers/layer_scale.py +22 -0
- capvector-pi05/src/vggt/layers/mlp.py +40 -0
capvector-pi05/src/openpi/policies/policy_test.py
ADDED
@@ -0,0 +1,34 @@
from openpi_client import action_chunk_broker
import pytest

from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


@pytest.mark.manual
def test_infer():
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")

    example = aloha_policy.make_aloha_example()
    result = policy.infer(example)

    assert result["actions"].shape == (config.model.action_horizon, 14)


@pytest.mark.manual
def test_broker():
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")

    broker = action_chunk_broker.ActionChunkBroker(
        policy,
        # Only execute the first half of the chunk.
        action_horizon=config.model.action_horizon // 2,
    )

    example = aloha_policy.make_aloha_example()
    for _ in range(config.model.action_horizon):
        outputs = broker.infer(example)
        assert outputs["actions"].shape == (14,)
capvector-pi05/src/openpi/serving/websocket_policy_server.py
ADDED
@@ -0,0 +1,90 @@
import asyncio
import http
import logging
import time
import traceback

from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server as _server
import websockets.frames

logger = logging.getLogger(__name__)


class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    Currently only implements the `load` and `infer` methods.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
    ) -> None:
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        asyncio.run(self.run())

    async def run(self):
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
            process_request=_health_check,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection):
        logger.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        await websocket.send(packer.pack(self._metadata))

        prev_total_time = None
        while True:
            try:
                start_time = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())

                infer_time = time.monotonic()
                action = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_time

                action["server_timing"] = {
                    "infer_ms": infer_time * 1000,
                }
                if prev_total_time is not None:
                    # We can only record the last total time since we also want to include the send time.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                logger.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise


def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
    if request.path == "/healthz":
        return connection.respond(http.HTTPStatus.OK, "OK\n")
    # Continue with the normal request handling.
    return None
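The server speaks a simple msgpack-over-websocket protocol: it sends its metadata as the first frame, then returns one packed action per packed observation. Below is a minimal client sketch (not part of this diff) that assumes the same `openpi_client` and `websockets` packages the server uses; the observation payload is hypothetical and depends on the policy being served.

import websockets.sync.client
from openpi_client import msgpack_numpy

def query_policy(uri: str = "ws://localhost:8000") -> dict:
    packer = msgpack_numpy.Packer()
    with websockets.sync.client.connect(uri, compression=None, max_size=None) as ws:
        metadata = msgpack_numpy.unpackb(ws.recv())  # first frame: server metadata
        print("server metadata:", metadata)
        obs = {"state": [0.0] * 14}  # hypothetical observation; real keys depend on the policy
        ws.send(packer.pack(obs))
        action = msgpack_numpy.unpackb(ws.recv())  # includes the "server_timing" dict added by the server
        return action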
capvector-pi05/src/openpi/shared/__init__.py
ADDED
File without changes
capvector-pi05/src/openpi/shared/download.py
ADDED
@@ -0,0 +1,194 @@
import concurrent.futures
import datetime
import logging
import os
import pathlib
import re
import shutil
import stat
import time
import urllib.parse

import filelock
import fsspec
import fsspec.generic
import tqdm_loggable.auto as tqdm

# Environment variable to control cache directory path, ~/.cache/openpi will be used by default.
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"
DEFAULT_CACHE_DIR = "~/.cache/openpi"

logger = logging.getLogger(__name__)


def get_cache_dir() -> pathlib.Path:
    cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    _set_folder_permission(cache_dir)
    return cache_dir


def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.

    It is safe to call this function concurrently from multiple processes.
    See `get_cache_dir` for more details on the cache directory.

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.
    """
    # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem.
    parsed = urllib.parse.urlparse(url)

    # Short circuit if this is a local path.
    if parsed.scheme == "":
        path = pathlib.Path(url)
        if not path.exists():
            raise FileNotFoundError(f"File not found at {url}")
        return path.resolve()

    cache_dir = get_cache_dir()

    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Check if the cache should be invalidated.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            return local_path

    try:
        lock_path = local_path.with_suffix(".lock")
        with filelock.FileLock(lock_path):
            # Ensure consistent permissions for the lock file.
            _ensure_permissions(lock_path)
            # First, remove the existing cache if it is expired.
            if invalidate_cache:
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()

            # Download the data to a local cache.
            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_suffix(".partial")
            _download_fsspec(url, scratch_path, **kwargs)

            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)

    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path


def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
    """Download a file from a remote filesystem to the local cache, and return the local path."""
    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
    info = fs.info(url)
    # Folders are represented by 0-byte objects with a trailing forward slash.
    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
        total_size = fs.du(url)
    else:
        total_size = info["size"]
    with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(fs.get, url, local_path, recursive=is_dir)
        while not future.done():
            current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file())
            pbar.update(current_size - pbar.n)
            time.sleep(1)
        pbar.update(total_size - pbar.n)


def _set_permission(path: pathlib.Path, target_permission: int):
    """chmod requires executable permission to be set, so we skip if the permission already matches the target."""
    if path.stat().st_mode & target_permission == target_permission:
        logger.debug(f"Skipping {path} because it already has correct permissions")
        return
    path.chmod(target_permission)
    logger.debug(f"Set {path} to {target_permission}")


def _set_folder_permission(folder_path: pathlib.Path) -> None:
    """Set folder permission to be read, write and searchable."""
    _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)


def _ensure_permissions(path: pathlib.Path) -> None:
    """Since we are sharing cache directory with containerized runtime as well as training script, we need to
    ensure that the cache directory has the correct permissions.
    """

    def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None:
        cache_dir = get_cache_dir()
        relative_path = path.relative_to(cache_dir)
        moving_path = cache_dir
        for part in relative_path.parts:
            _set_folder_permission(moving_path / part)
            moving_path = moving_path / part

    def _set_file_permission(file_path: pathlib.Path) -> None:
        """Set all files to be read & writable; if it is a script, keep it executable."""
        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
        if file_path.stat().st_mode & 0o100:
            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        else:
            _set_permission(file_path, file_rw)

    _setup_folder_permission_between_cache_dir_and_path(path)
    for root, dirs, files in os.walk(str(path)):
        root_path = pathlib.Path(root)
        for file in files:
            file_path = root_path / file
            _set_file_permission(file_path)

        for dir in dirs:
            dir_path = root_path / dir
            _set_folder_permission(dir_path)


def _get_mtime(year: int, month: int, day: int) -> float:
    """Get the mtime of a given date at midnight UTC."""
    date = datetime.datetime(year, month, day, tzinfo=datetime.UTC)
    return time.mktime(date.timetuple())


# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format).
# Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}


def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Invalidate the cache if it is expired. Return True if the cache was invalidated."""

    assert local_path.exists(), f"File not found at {local_path}"

    relative_path = str(local_path.relative_to(cache_dir))
    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
        if pattern.match(relative_path):
            # Remove if not newer than the expiration timestamp.
            return local_path.stat().st_mtime <= expire_time

    return False
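A short usage sketch for `maybe_download` (not part of the diff), assuming `gcsfs` is installed so fsspec can resolve `gs://` URLs; the bucket path is the public test asset exercised in `download_test.py` below.

import openpi.shared.download as download

print(download.get_cache_dir())  # ~/.cache/openpi unless OPENPI_DATA_HOME is set

path = download.maybe_download("gs://openpi-assets/testdata/random/random_512kb.bin")
# A second call is a cache hit: it returns the same local path without re-downloading.
assert download.maybe_download("gs://openpi-assets/testdata/random/random_512kb.bin") == path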
capvector-pi05/src/openpi/shared/download_test.py
ADDED
@@ -0,0 +1,54 @@
import pathlib

import pytest

import openpi.shared.download as download


@pytest.fixture(scope="session", autouse=True)
def set_openpi_data_home(tmp_path_factory):
    temp_dir = tmp_path_factory.mktemp("openpi_data")
    with pytest.MonkeyPatch().context() as mp:
        mp.setenv("OPENPI_DATA_HOME", str(temp_dir))
        yield


def test_download_local(tmp_path: pathlib.Path):
    local_path = tmp_path / "local"
    local_path.touch()

    result = download.maybe_download(str(local_path))
    assert result == local_path

    with pytest.raises(FileNotFoundError):
        download.maybe_download("bogus")


def test_download_gs_dir():
    remote_path = "gs://openpi-assets/testdata/random"

    local_path = download.maybe_download(remote_path)
    assert local_path.exists()

    new_local_path = download.maybe_download(remote_path)
    assert new_local_path == local_path


def test_download_gs():
    remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin"

    local_path = download.maybe_download(remote_path)
    assert local_path.exists()

    new_local_path = download.maybe_download(remote_path)
    assert new_local_path == local_path


def test_download_fsspec():
    remote_path = "gs://big_vision/paligemma_tokenizer.model"

    local_path = download.maybe_download(remote_path, gs={"token": "anon"})
    assert local_path.exists()

    new_local_path = download.maybe_download(remote_path, gs={"token": "anon"})
    assert new_local_path == local_path
capvector-pi05/src/openpi/shared/image_tools.py
ADDED
@@ -0,0 +1,186 @@
import functools

import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F  # noqa: N812

import openpi.shared.array_typing as at


@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@at.typecheck
def resize_with_pad(
    images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"],
    height: int,
    width: int,
    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,
) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]:
    """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].
    """
    has_batch_dim = images.ndim == 4
    if not has_batch_dim:
        images = images[None]  # type: ignore
    cur_height, cur_width = images.shape[1:3]
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_images = jax.image.resize(
        images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method
    )
    if images.dtype == jnp.uint8:
        # round from float back to uint8
        resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8)
    elif images.dtype == jnp.float32:
        resized_images = resized_images.clip(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w
    padded_images = jnp.pad(
        resized_images,
        ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)),
        constant_values=0 if images.dtype == jnp.uint8 else -1.0,
    )

    if not has_batch_dim:
        padded_images = padded_images[0]
    return padded_images


def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with same shape format as input
    """
    # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
    if images.shape[-1] <= 4:  # Assume channels-last format
        channels_last = True
        # Convert to channels-first for torch operations
        if images.dim() == 3:
            images = images.unsqueeze(0)  # Add batch dimension
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]
    else:
        channels_last = False
        if images.dim() == 3:
            images = images.unsqueeze(0)  # Add batch dimension

    batch_size, channels, cur_height, cur_width = images.shape

    # Calculate resize ratio
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Resize
    resized_images = F.interpolate(
        images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
    )

    # Handle dtype-specific clipping
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Calculate padding
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Pad
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to original format if needed
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    if batch_size == 1 and images.shape[0] == 1:
        padded_images = padded_images.squeeze(0)  # Remove batch dimension if it was added

    return padded_images


def replace_padding_0to1_torch(image: torch.Tensor) -> torch.Tensor:
    """PyTorch version of replace_padding_0to1.
    OpenPI requires images with 0-value padding, while the VGGT series requires 1-value padding.
    This performs the conversion via a bounding-box based padding replacement.
    Args:
        image: Tensor of shape [*b, h, w, c]
    Returns:
        Padding-replaced tensor with same shape as input
    """
    single = False
    if image.dim() == 3:
        image = image.unsqueeze(0)
        single = True

    b, h, w, c = image.shape
    device = image.device

    nonzero_any = (image != 0).any(dim=-1)

    row_any = nonzero_any.any(dim=2)
    col_any = nonzero_any.any(dim=1)

    top = row_any.to(torch.float32).argmax(dim=1)
    bottom = h - 1 - row_any.flip(dims=[1]).to(torch.float32).argmax(dim=1)
    left = col_any.to(torch.float32).argmax(dim=1)
    right = w - 1 - col_any.flip(dims=[1]).to(torch.float32).argmax(dim=1)

    has_any = row_any.any(dim=1)
    top = torch.where(has_any, top, torch.zeros_like(top))
    bottom = torch.where(has_any, bottom, torch.full_like(bottom, h - 1))
    left = torch.where(has_any, left, torch.zeros_like(left))
    right = torch.where(has_any, right, torch.full_like(right, w - 1))

    rows = torch.arange(h, device=device).view(1, h, 1)
    cols = torch.arange(w, device=device).view(1, 1, w)
    top_v = top.view(b, 1, 1)
    bottom_v = bottom.view(b, 1, 1)
    left_v = left.view(b, 1, 1)
    right_v = right.view(b, 1, 1)

    row_mask = (rows >= top_v) & (rows <= bottom_v)
    col_mask = (cols >= left_v) & (cols <= right_v)
    inside_mask = row_mask & col_mask

    padding_mask = ~inside_mask

    pixel_zero = (image == 0).all(dim=-1)

    final_mask = padding_mask & pixel_zero

    if final_mask.any():
        mask_exp = final_mask.unsqueeze(-1).expand_as(image)
        one_t = torch.tensor(1, dtype=image.dtype, device=device)
        image = torch.where(mask_exp, one_t, image)

    if single:
        image = image.squeeze(0)
    return image
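A small usage sketch for the torch helpers above (not part of the diff); the resolutions are arbitrary. `resize_with_pad_torch` keeps the aspect ratio and pads, and `replace_padding_0to1_torch` then turns zero-valued padding outside the content bounding box into ones.

import torch

from openpi.shared import image_tools

frame = torch.rand(480, 640, 3) * 2 - 1                       # float32 in [-1, 1], channels-last
resized = image_tools.resize_with_pad_torch(frame, 224, 224)  # -> (224, 224, 3), padded with -1.0

image = torch.zeros(224, 224, 3)      # all-zero "padding"
image[32:192, 16:208, :] = 0.5        # non-zero content region
converted = image_tools.replace_padding_0to1_torch(image)
print(converted[0, 0, 0].item())      # 1.0: padding outside the content box became 1
print(converted[100, 100, 0].item())  # 0.5: content untouched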
capvector-pi05/src/openpi/shared/image_tools_test.py
ADDED
@@ -0,0 +1,37 @@
import jax.numpy as jnp

from openpi.shared import image_tools


def test_resize_with_pad_shapes():
    # Test case 1: Resize image with larger dimensions
    images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8)  # Input images of shape (batch_size, height, width, channels)
    height = 20
    width = 20
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (2, height, width, 3)
    assert jnp.all(resized_images == 0)

    # Test case 2: Resize image with smaller dimensions
    images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8)
    height = 15
    width = 15
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (3, height, width, 3)
    assert jnp.all(resized_images == 0)

    # Test case 3: Resize image with the same dimensions
    images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8)
    height = 50
    width = 50
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert jnp.all(resized_images == 0)

    # Test case 4: Resize image with odd-numbered padding
    images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8)
    height = 60
    width = 80
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert jnp.all(resized_images == 0)
capvector-pi05/src/openpi/shared/nnx_utils.py
ADDED
@@ -0,0 +1,69 @@
from collections.abc import Callable
import dataclasses
import functools
import inspect
import re
from typing import Any, ParamSpec, TypeVar

import flax.nnx as nnx
import jax

P = ParamSpec("P")
R = TypeVar("R")


def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]:
    """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process.

    Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much
    more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module
    mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must
    traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details.

    `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by
    `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was
    when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded
    after the method call completes.
    """
    if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)):
        raise ValueError("module_jit must only be used on bound methods of nnx.Modules.")

    graphdef, state = nnx.split(meth.__self__)

    def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R:
        module = nnx.merge(graphdef, state)
        return meth.__func__(module, *args, **kwargs)

    jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs)

    @functools.wraps(meth)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return jitted_fn(state, *args, **kwargs)

    return wrapper


@dataclasses.dataclass(frozen=True)
class PathRegex:
    """NNX Filter that matches paths using a regex.

    By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument.
    """

    pattern: str | re.Pattern
    sep: str = "/"

    def __post_init__(self):
        if not isinstance(self.pattern, re.Pattern):
            object.__setattr__(self, "pattern", re.compile(self.pattern))

    def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool:
        joined_path = self.sep.join(str(x) for x in path)
        assert isinstance(self.pattern, re.Pattern)
        return self.pattern.fullmatch(joined_path) is not None


def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State:
    """Apply a function to the leaves of the state that match the filter."""
    filtered_keys = set(state.filter(filter).flat_state())
    return state.map(lambda k, v: fn(v) if k in filtered_keys else v)
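A minimal sketch of `module_jit` on a toy `nnx.Module` (not part of the diff); the layer sizes and seed are made up. The returned function behaves like the bound method but runs under a plain `jax.jit` with the module state captured at wrap time.

import flax.nnx as nnx
import jax.numpy as jnp

from openpi.shared import nnx_utils


class TinyMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.fc = nnx.Linear(4, 2, rngs=rngs)

    def __call__(self, x):
        return self.fc(x)


model = TinyMLP(nnx.Rngs(0))
forward = nnx_utils.module_jit(model.__call__)  # module state is frozen here
out = forward(jnp.ones((1, 4)))                 # jitted call, output shape (1, 2)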
capvector-pi05/src/openpi/shared/normalize.py
ADDED
@@ -0,0 +1,146 @@
import json
import pathlib

import numpy as np
import numpydantic
import pydantic


@pydantic.dataclasses.dataclass
class NormStats:
    mean: numpydantic.NDArray
    std: numpydantic.NDArray
    q01: numpydantic.NDArray | None = None  # 1st quantile
    q99: numpydantic.NDArray | None = None  # 99th quantile


class RunningStats:
    """Compute running statistics of a batch of vectors."""

    def __init__(self):
        self._count = 0
        self._mean = None
        self._mean_of_squares = None
        self._min = None
        self._max = None
        self._histograms = None
        self._bin_edges = None
        self._num_quantile_bins = 5000  # for computing quantiles on the fly

    def update(self, batch: np.ndarray) -> None:
        """
        Update the running statistics with a batch of vectors.

        Args:
            batch (np.ndarray): An array where all dimensions except the last are batch dimensions.
        """
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements, vector_length = batch.shape
        if self._count == 0:
            self._mean = np.mean(batch, axis=0)
            self._mean_of_squares = np.mean(batch**2, axis=0)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
            self._bin_edges = [
                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
                for i in range(vector_length)
            ]
        else:
            if vector_length != self._mean.size:
                raise ValueError("The length of new vectors does not match the initialized vector length.")
            new_max = np.max(batch, axis=0)
            new_min = np.min(batch, axis=0)
            max_changed = np.any(new_max > self._max)
            min_changed = np.any(new_min < self._min)
            self._max = np.maximum(self._max, new_max)
            self._min = np.minimum(self._min, new_min)

            if max_changed or min_changed:
                self._adjust_histograms()

        self._count += num_elements

        batch_mean = np.mean(batch, axis=0)
        batch_mean_of_squares = np.mean(batch**2, axis=0)

        # Update running mean and mean of squares.
        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count)

        self._update_histograms(batch)

    def get_statistics(self) -> NormStats:
        """
        Compute and return the statistics of the vectors processed so far.

        Returns:
            NormStats: The computed statistics.
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")

        variance = self._mean_of_squares - self._mean**2
        stddev = np.sqrt(np.maximum(0, variance))
        q01, q99 = self._compute_quantiles([0.01, 0.99])
        return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99)

    def _adjust_histograms(self):
        """Adjust histograms when min or max changes."""
        for i in range(len(self._histograms)):
            old_edges = self._bin_edges[i]
            new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1)

            # Redistribute the existing histogram counts to the new bins
            new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i])

            self._histograms[i] = new_hist
            self._bin_edges[i] = new_edges

    def _update_histograms(self, batch: np.ndarray) -> None:
        """Update histograms with new vectors."""
        for i in range(batch.shape[1]):
            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
            self._histograms[i] += hist

    def _compute_quantiles(self, quantiles):
        """Compute quantiles based on histograms."""
        results = []
        for q in quantiles:
            target_count = q * self._count
            q_values = []
            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
                cumsum = np.cumsum(hist)
                idx = np.searchsorted(cumsum, target_count)
                q_values.append(edges[idx])
            results.append(np.array(q_values))
        return results


class _NormStatsDict(pydantic.BaseModel):
    norm_stats: dict[str, NormStats]


def serialize_json(norm_stats: dict[str, NormStats]) -> str:
    """Serialize the running statistics to a JSON string."""
    return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2)


def deserialize_json(data: str) -> dict[str, NormStats]:
    """Deserialize the running statistics from a JSON string."""
    return _NormStatsDict(**json.loads(data)).norm_stats


def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None:
    """Save the normalization stats to a directory."""
    path = pathlib.Path(directory) / "norm_stats.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(serialize_json(norm_stats))


def load(directory: pathlib.Path | str) -> dict[str, NormStats]:
    """Load the normalization stats from a directory."""
    path = pathlib.Path(directory) / "norm_stats.json"
    if not path.exists():
        raise FileNotFoundError(f"Norm stats file not found at: {path}")
    return deserialize_json(path.read_text())
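A short round-trip sketch for `RunningStats` and the save/load helpers (not part of the diff); the data and output directory are arbitrary.

import pathlib
import tempfile

import numpy as np

import openpi.shared.normalize as normalize

stats = normalize.RunningStats()
stats.update(np.random.rand(100, 7))            # 100 vectors of length 7
norm_stats = {"state": stats.get_statistics()}  # NormStats with mean/std/q01/q99

out_dir = pathlib.Path(tempfile.mkdtemp())
normalize.save(out_dir, norm_stats)             # writes <out_dir>/norm_stats.json
loaded = normalize.load(out_dir)
assert np.allclose(loaded["state"].mean, norm_stats["state"].mean)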
capvector-pi05/src/openpi/shared/normalize_test.py
ADDED
@@ -0,0 +1,43 @@
import numpy as np

import openpi.shared.normalize as normalize


def test_normalize_update():
    arr = np.arange(12).reshape(4, 3)  # 4 vectors of length 3

    stats = normalize.RunningStats()
    for i in range(len(arr)):
        stats.update(arr[i : i + 1])  # Update with one vector at a time
    results = stats.get_statistics()

    assert np.allclose(results.mean, np.mean(arr, axis=0))
    assert np.allclose(results.std, np.std(arr, axis=0))


def test_serialize_deserialize():
    stats = normalize.RunningStats()
    stats.update(np.arange(12).reshape(4, 3))  # 4 vectors of length 3

    norm_stats = {"test": stats.get_statistics()}
    norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats))
    assert np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean)
    assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std)


def test_multiple_batch_dimensions():
    # Test with multiple batch dimensions: (2, 3, 4) where 4 is vector dimension
    batch_shape = (2, 3, 4)
    arr = np.random.rand(*batch_shape)

    stats = normalize.RunningStats()
    stats.update(arr)  # Should handle (2, 3, 4) -> reshape to (6, 4)
    results = stats.get_statistics()

    # Flatten batch dimensions and compute expected stats
    flattened = arr.reshape(-1, arr.shape[-1])  # (6, 4)
    expected_mean = np.mean(flattened, axis=0)
    expected_std = np.std(flattened, axis=0)

    assert np.allclose(results.mean, expected_mean)
    assert np.allclose(results.std, expected_std)
capvector-pi05/src/openpi/training/checkpoints.py
ADDED
@@ -0,0 +1,159 @@
from __future__ import annotations

import asyncio
import concurrent.futures as futures
import dataclasses
import logging
from typing import Protocol

from etils import epath
import jax
import orbax.checkpoint as ocp
import orbax.checkpoint.future as future

from openpi.shared import array_typing as at
import openpi.shared.normalize as _normalize
import openpi.training.data_loader as _data_loader
import openpi.training.utils as training_utils


def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    resuming = False
    if checkpoint_dir.exists():
        if overwrite:
            checkpoint_dir.rmtree()
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
        elif resume:
            resuming = True
        else:
            raise FileExistsError(
                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
                "to indicate how to handle it."
            )

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    mngr = ocp.CheckpointManager(
        checkpoint_dir,
        item_handlers={
            "assets": CallbackHandler(),
            "train_state": ocp.PyTreeCheckpointHandler(),
            "params": ocp.PyTreeCheckpointHandler(),
        },
        options=ocp.CheckpointManagerOptions(
            max_to_keep=1,
            keep_period=keep_period,
            create=False,
            async_options=ocp.AsyncOptions(timeout_secs=7200),
        ),
    )

    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
    # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
    # checkpoint, since it will fail.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False

    return mngr, resuming


def save_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int,
):
    def save_assets(directory: epath.Path):
        # Save the normalization stats.
        data_config = data_loader.data_config()
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(directory / data_config.asset_id, norm_stats)

    # Split params that can be used for inference into a separate item.
    with at.disable_typechecking():
        train_state, params = _split_params(state)
    items = {
        "assets": save_assets,
        "train_state": train_state,
        "params": {"params": params},
    }
    checkpoint_manager.save(step, items)


def restore_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int | None = None,
) -> training_utils.TrainState:
    del data_loader

    with at.disable_typechecking():
        # Split params that can be used for inference into a separate item.
        train_state, params = _split_params(state)
        restored = checkpoint_manager.restore(
            step,
            items={
                "train_state": train_state,
                "params": {"params": params},
            },
        )
    return _merge_params(restored["train_state"], restored["params"])


def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
    norm_stats_dir = epath.Path(assets_dir) / asset_id
    norm_stats = _normalize.load(norm_stats_dir)
    logging.info(f"Loaded norm stats from {norm_stats_dir}")
    return norm_stats


class Callback(Protocol):
    def __call__(self, directory: epath.Path) -> None: ...


class CallbackHandler(ocp.AsyncCheckpointHandler):
    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""

    def save(self, directory: epath.Path, args: CallbackSave):
        if jax.process_index() == 0:
            args.callback(directory)

    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
        return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")


@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
    callback: Callback


@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs): ...


def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
    if state.ema_params is not None:
        params = state.ema_params
        train_state = dataclasses.replace(state, ema_params=None)
    else:
        params = state.params
        train_state = dataclasses.replace(state, params={})
    return train_state, params


def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
    # Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split.
    if train_state.params:
        return dataclasses.replace(train_state, ema_params=params["params"])
    return dataclasses.replace(train_state, params=params["params"])
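A minimal sketch of the checkpoint-directory lifecycle (not part of the diff), assuming a throwaway local path; the keep period is arbitrary. A fresh directory yields an empty step list and `resuming == False`.

from openpi.training import checkpoints as _checkpoints

mngr, resuming = _checkpoints.initialize_checkpoint_dir(
    "/tmp/openpi_ckpt_demo",  # hypothetical scratch path
    keep_period=1000,
    overwrite=True,
    resume=False,
)
print(list(mngr.all_steps()), resuming)  # [] False until save_state() writes the first step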
capvector-pi05/src/openpi/training/config.py
ADDED
@@ -0,0 +1,1033 @@
| 1 |
+
"""See _CONFIGS for the list of available configs."""
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
from collections.abc import Sequence
|
| 5 |
+
import dataclasses
|
| 6 |
+
import difflib
|
| 7 |
+
import logging
|
| 8 |
+
import pathlib
|
| 9 |
+
from typing import Any, Literal, Protocol, TypeAlias
|
| 10 |
+
|
| 11 |
+
import etils.epath as epath
|
| 12 |
+
import flax.nnx as nnx
|
| 13 |
+
from typing_extensions import override
|
| 14 |
+
import tyro
|
| 15 |
+
|
| 16 |
+
import openpi.models.model as _model
|
| 17 |
+
import openpi.models.pi0_config as pi0_config
|
| 18 |
+
import openpi.models.pi0_fast as pi0_fast
|
| 19 |
+
import openpi.models.tokenizer as _tokenizer
|
| 20 |
+
import openpi.policies.aloha_policy as aloha_policy
|
| 21 |
+
import openpi.policies.droid_policy as droid_policy
|
| 22 |
+
import openpi.policies.libero_policy as libero_policy
|
| 23 |
+
import openpi.shared.download as _download
|
| 24 |
+
import openpi.shared.normalize as _normalize
|
| 25 |
+
import openpi.training.droid_rlds_dataset as droid_rlds_dataset
|
| 26 |
+
import openpi.training.misc.roboarena_config as roboarena_config
|
| 27 |
+
import openpi.training.optimizer as _optimizer
|
| 28 |
+
import openpi.training.weight_loaders as weight_loaders
|
| 29 |
+
import openpi.transforms as _transforms
|
| 30 |
+
|
| 31 |
+
ModelType: TypeAlias = _model.ModelType
|
| 32 |
+
# Work around a tyro issue with using nnx.filterlib.Filter directly.
|
| 33 |
+
Filter: TypeAlias = nnx.filterlib.Filter
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclasses.dataclass(frozen=True)
|
| 37 |
+
class AssetsConfig:
|
| 38 |
+
"""Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline.
|
| 39 |
+
|
| 40 |
+
These assets will be replicated inside the checkpoint under the `assets/asset_id` directory.
|
| 41 |
+
|
| 42 |
+
This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other
|
| 43 |
+
centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint
|
| 44 |
+
during fine-tuning, use:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
AssetsConfig(
|
| 48 |
+
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
|
| 49 |
+
asset_id="trossen",
|
| 50 |
+
)
|
| 51 |
+
```
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
# Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from
|
| 55 |
+
# a different checkpoint (e.g., base model checkpoint) or some other centralized location.
|
| 56 |
+
assets_dir: str | None = None
|
| 57 |
+
|
| 58 |
+
# Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe
|
| 59 |
+
# different robot platforms.
|
| 60 |
+
asset_id: str | None = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclasses.dataclass(frozen=True)
|
| 64 |
+
class DataConfig:
|
| 65 |
+
# LeRobot repo id. If None, fake data will be created.
|
| 66 |
+
repo_id: str | None = None
|
| 67 |
+
# Directory within the assets directory containing the data assets.
|
| 68 |
+
asset_id: str | None = None
|
| 69 |
+
# Contains precomputed normalization stats. If None, normalization will not be performed.
|
| 70 |
+
norm_stats: dict[str, _transforms.NormStats] | None = None
|
| 71 |
+
|
| 72 |
+
# Used to adapt the inputs from a dataset-specific format to a common format
|
| 73 |
+
# which is expected by the data transforms.
|
| 74 |
+
repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
|
| 75 |
+
# Data transforms, typically include robot specific transformations. Will be applied
|
| 76 |
+
# before the data is normalized. See `model.Observation` and `model.Actions` to learn about the
|
| 77 |
+
# normalized data.
|
| 78 |
+
data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
|
| 79 |
+
# Model specific transforms. Will be applied after the data is normalized.
|
| 80 |
+
model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
|
| 81 |
+
# If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
|
| 82 |
+
use_quantile_norm: bool = False
|
| 83 |
+
|
| 84 |
+
# Names of keys that will be used by the data loader to generate the action sequence. The length of the
|
| 85 |
+
# sequence is defined by the `action_horizon` field in the model config. This should be adjusted if your
|
| 86 |
+
# LeRobot dataset is using different keys to represent the action.
|
| 87 |
+
action_sequence_keys: Sequence[str] = ("actions",)
|
| 88 |
+
|
| 89 |
+
# If true, will use the LeRobot dataset task to define the prompt.
|
| 90 |
+
prompt_from_task: bool = False
|
| 91 |
+
|
| 92 |
+
# Only used for the RLDS data loader (i.e., currently only used for DROID).
|
| 93 |
+
rlds_data_dir: str | None = None
|
| 94 |
+
# Action space for DROID dataset.
|
| 95 |
+
action_space: droid_rlds_dataset.DroidActionSpace | None = None
|
| 96 |
+
# Path to the data filter file for DROID dataset
|
| 97 |
+
filter_dict_path: str | None = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class GroupFactory(Protocol):
|
| 101 |
+
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
|
| 102 |
+
"""Create a group."""
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclasses.dataclass(frozen=True)
|
| 106 |
+
class ModelTransformFactory(GroupFactory):
|
| 107 |
+
"""Creates model transforms for standard pi0 models."""
|
| 108 |
+
|
| 109 |
+
# If provided, will determine the default prompt that will be used by the model.
|
| 110 |
+
default_prompt: str | None = None
|
| 111 |
+
|
| 112 |
+
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
|
| 113 |
+
match model_config.model_type:
|
| 114 |
+
case _model.ModelType.PI0:
|
| 115 |
+
return _transforms.Group(
|
| 116 |
+
inputs=[
|
| 117 |
+
_transforms.InjectDefaultPrompt(self.default_prompt),
|
| 118 |
+
_transforms.ResizeImages(224, 224),
|
| 119 |
+
_transforms.TokenizePrompt(
|
| 120 |
+
_tokenizer.PaligemmaTokenizer(model_config.max_token_len),
|
| 121 |
+
),
|
| 122 |
+
_transforms.PadStatesAndActions(model_config.action_dim),
|
| 123 |
+
],
|
| 124 |
+
)
|
| 125 |
+
case _model.ModelType.PI05:
|
| 126 |
+
assert isinstance(model_config, pi0_config.Pi0Config)
|
| 127 |
+
return _transforms.Group(
|
| 128 |
+
inputs=[
|
| 129 |
+
_transforms.InjectDefaultPrompt(self.default_prompt),
|
| 130 |
+
_transforms.ResizeImages(224, 224),
|
| 131 |
+
_transforms.TokenizePrompt(
|
| 132 |
+
_tokenizer.PaligemmaTokenizer(model_config.max_token_len),
|
| 133 |
+
discrete_state_input=model_config.discrete_state_input,
|
| 134 |
+
),
|
| 135 |
+
_transforms.PadStatesAndActions(model_config.action_dim),
|
| 136 |
+
],
|
| 137 |
+
)
|
| 138 |
+
case _model.ModelType.PI0_FAST:
|
| 139 |
+
tokenizer_cls = (
|
| 140 |
+
_tokenizer.FASTTokenizer
|
| 141 |
+
if model_config.fast_model_tokenizer is None
|
| 142 |
+
else model_config.fast_model_tokenizer
|
| 143 |
+
)
|
| 144 |
+
tokenizer_kwargs = (
|
| 145 |
+
{} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs
|
| 146 |
+
)
|
| 147 |
+
return _transforms.Group(
|
| 148 |
+
inputs=[
|
| 149 |
+
_transforms.InjectDefaultPrompt(self.default_prompt),
|
| 150 |
+
_transforms.ResizeImages(224, 224),
|
| 151 |
+
_transforms.TokenizeFASTInputs(
|
| 152 |
+
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
|
| 153 |
+
),
|
| 154 |
+
],
|
| 155 |
+
outputs=[
|
| 156 |
+
_transforms.ExtractFASTActions(
|
| 157 |
+
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
|
| 158 |
+
action_horizon=model_config.action_horizon,
|
| 159 |
+
action_dim=model_config.action_dim,
|
| 160 |
+
)
|
| 161 |
+
],
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclasses.dataclass(frozen=True)
|
| 166 |
+
class DataConfigFactory(abc.ABC):
|
| 167 |
+
# The LeRobot repo id.
|
| 168 |
+
repo_id: str = tyro.MISSING
|
| 169 |
+
# Determines how the assets will be loaded.
|
| 170 |
+
assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig)
|
| 171 |
+
# Base config that will be updated by the factory.
|
| 172 |
+
base_config: tyro.conf.Suppress[DataConfig | None] = None
|
| 173 |
+
|
| 174 |
+
@abc.abstractmethod
|
| 175 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 176 |
+
"""Create a data config."""
|
| 177 |
+
|
| 178 |
+
def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 179 |
+
repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None
|
| 180 |
+
asset_id = self.assets.asset_id or repo_id
|
| 181 |
+
return dataclasses.replace(
|
| 182 |
+
self.base_config or DataConfig(),
|
| 183 |
+
repo_id=repo_id,
|
| 184 |
+
asset_id=asset_id,
|
| 185 |
+
norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
|
| 186 |
+
use_quantile_norm=model_config.model_type != ModelType.PI0,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
|
| 190 |
+
if asset_id is None:
|
| 191 |
+
return None
|
| 192 |
+
try:
|
| 193 |
+
data_assets_dir = str(assets_dir / asset_id)
|
| 194 |
+
norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))
|
| 195 |
+
logging.info(f"Loaded norm stats from {data_assets_dir}")
|
| 196 |
+
return norm_stats
|
| 197 |
+
except FileNotFoundError:
|
| 198 |
+
logging.info(f"Norm stats not found in {data_assets_dir}, skipping.")
|
| 199 |
+
return None
|
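Annotation (not part of the diff): as a quick illustration of the asset lookup above, `create_base_config` falls back from `assets.asset_id` to the repo id and from `assets.assets_dir` to the per-config assets directory, and `_load_norm_stats` then looks for stats under `<assets_dir>/<asset_id>`. The values below are hypothetical and only reproduce the path resolution.

```python
# Hypothetical values, showing only the path resolution performed above.
import pathlib

repo_id = "your_org/your_dataset"                        # DataConfigFactory.repo_id (hypothetical)
assets_dir = None                                        # AssetsConfig.assets_dir (not set)
asset_id = None                                          # AssetsConfig.asset_id (not set)
config_assets_dirs = pathlib.Path("./assets/my_config")  # TrainConfig.assets_dirs (hypothetical)

resolved_asset_id = asset_id or repo_id                  # -> "your_org/your_dataset"
resolved_dir = pathlib.Path(assets_dir or config_assets_dirs)
norm_stats_location = resolved_dir / resolved_asset_id
print(norm_stats_location)                               # assets/my_config/your_org/your_dataset
```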
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@dataclasses.dataclass(frozen=True)
|
| 203 |
+
class FakeDataConfig(DataConfigFactory):
|
| 204 |
+
repo_id: str = "fake"
|
| 205 |
+
|
| 206 |
+
@override
|
| 207 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 208 |
+
return DataConfig(repo_id=self.repo_id)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@dataclasses.dataclass(frozen=True)
|
| 212 |
+
class SimpleDataConfig(DataConfigFactory):
|
| 213 |
+
# Factory for the data transforms.
|
| 214 |
+
data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory)
|
| 215 |
+
# Factory for the model transforms.
|
| 216 |
+
model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory)
|
| 217 |
+
|
| 218 |
+
@override
|
| 219 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 220 |
+
return dataclasses.replace(
|
| 221 |
+
self.create_base_config(assets_dirs, model_config),
|
| 222 |
+
data_transforms=self.data_transforms(model_config),
|
| 223 |
+
model_transforms=self.model_transforms(model_config),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@dataclasses.dataclass(frozen=True)
|
| 228 |
+
class LeRobotAlohaDataConfig(DataConfigFactory):
|
| 229 |
+
# If true, will convert joint dimensions to deltas with respect to the current state before passing to the model.
|
| 230 |
+
# Gripper dimensions will remain in absolute values.
|
| 231 |
+
use_delta_joint_actions: bool = True
|
| 232 |
+
# If provided, will be injected into the input data if the "prompt" key is not present.
|
| 233 |
+
default_prompt: str | None = None
|
| 234 |
+
# If true, this will convert the joint and gripper values from the standard Aloha space to
|
| 235 |
+
# the space used by the pi internal runtime which was used to train the base model. People who
|
| 236 |
+
# use standard Aloha data should set this to true.
|
| 237 |
+
adapt_to_pi: bool = True
|
| 238 |
+
|
| 239 |
+
# Repack transforms.
|
| 240 |
+
repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field(
|
| 241 |
+
default=_transforms.Group(
|
| 242 |
+
inputs=[
|
| 243 |
+
_transforms.RepackTransform(
|
| 244 |
+
{
|
| 245 |
+
"images": {"cam_high": "observation.images.top"},
|
| 246 |
+
"state": "observation.state",
|
| 247 |
+
"actions": "action",
|
| 248 |
+
}
|
| 249 |
+
)
|
| 250 |
+
]
|
| 251 |
+
)
|
| 252 |
+
)
|
| 253 |
+
# Action keys that will be used to read the action sequence from the dataset.
|
| 254 |
+
action_sequence_keys: Sequence[str] = ("action",)
|
| 255 |
+
|
| 256 |
+
@override
|
| 257 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 258 |
+
data_transforms = _transforms.Group(
|
| 259 |
+
inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)],
|
| 260 |
+
outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)],
|
| 261 |
+
)
|
| 262 |
+
if self.use_delta_joint_actions:
|
| 263 |
+
delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1)
|
| 264 |
+
data_transforms = data_transforms.push(
|
| 265 |
+
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
| 266 |
+
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)
|
| 270 |
+
|
| 271 |
+
return dataclasses.replace(
|
| 272 |
+
self.create_base_config(assets_dirs, model_config),
|
| 273 |
+
repack_transforms=self.repack_transforms,
|
| 274 |
+
data_transforms=data_transforms,
|
| 275 |
+
model_transforms=model_transforms,
|
| 276 |
+
action_sequence_keys=self.action_sequence_keys,
|
| 277 |
+
)
|
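Annotation (not part of the diff): to make the delta-action handling above concrete, here is a small standalone sketch. It assumes `make_bool_mask(6, -1, 6, -1)` yields six `True` values, one `False`, six `True`, one `False` (joints converted to deltas, grippers kept absolute); the NumPy code below is illustration only and is not the library's `DeltaActions`/`AbsoluteActions` implementation.

```python
import numpy as np

# Assumed semantics of make_bool_mask(6, -1, 6, -1): positive counts -> True runs, negative -> False runs.
mask = np.array([True] * 6 + [False] + [True] * 6 + [False])

state = np.arange(14, dtype=np.float32)         # current joint/gripper state (hypothetical values)
actions = np.ones((50, 14), dtype=np.float32)   # absolute action chunk of length 50

# Delta conversion: subtract the current state on masked (joint) dims, leave grippers absolute.
delta_actions = np.where(mask, actions - state, actions)

# Inverse conversion (what the output transform undoes at inference time).
absolute_again = np.where(mask, delta_actions + state, delta_actions)
assert np.allclose(absolute_again, actions)
```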
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@dataclasses.dataclass(frozen=True)
|
| 281 |
+
class LeRobotLiberoDataConfig(DataConfigFactory):
|
| 282 |
+
"""
|
| 283 |
+
This config is used to configure transforms that are applied at various parts of the data pipeline.
|
| 284 |
+
For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
|
| 285 |
+
comments below.
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
extra_delta_transform: bool = False
|
| 289 |
+
|
| 290 |
+
@override
|
| 291 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 292 |
+
# The repack transform is *only* applied to the data coming from the dataset,
|
| 293 |
+
# and *not* during inference. We can use it to make inputs from the dataset look
|
| 294 |
+
# as close as possible to those coming from the inference environment (e.g. match the keys).
|
| 295 |
+
# Below, we match the keys in the dataset (which we defined in the data conversion script) to
|
| 296 |
+
# the keys we use in our inference pipeline (defined in the inference script for libero).
|
| 297 |
+
# For your own dataset, first figure out what keys your environment passes to the policy server
|
| 298 |
+
# and then modify the mappings below so your dataset's keys get matched to those target keys.
|
| 299 |
+
# The repack transform simply remaps key names here.
|
| 300 |
+
repack_transform = _transforms.Group(
|
| 301 |
+
inputs=[
|
| 302 |
+
_transforms.RepackTransform(
|
| 303 |
+
{
|
| 304 |
+
"observation/image": "image",
|
| 305 |
+
"observation/wrist_image": "wrist_image",
|
| 306 |
+
"observation/state": "state",
|
| 307 |
+
"actions": "actions",
|
| 308 |
+
"prompt": "prompt",
|
| 309 |
+
}
|
| 310 |
+
)
|
| 311 |
+
]
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# The data transforms are applied to the data coming from the dataset *and* during inference.
|
| 315 |
+
# Below, we define the transforms for data going into the model (``inputs``) and the transforms
|
| 316 |
+
# for data coming out of the model (``outputs``) (the latter is only used during inference).
|
| 317 |
+
# We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
|
| 318 |
+
# how to modify the transforms to match your dataset. Once you created your own transforms, you can
|
| 319 |
+
# replace the transforms below with your own.
|
| 320 |
+
data_transforms = _transforms.Group(
|
| 321 |
+
inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],
|
| 322 |
+
outputs=[libero_policy.LiberoOutputs()],
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# One additional data transform: pi0 models are trained on delta actions (relative to the first
|
| 326 |
+
# state in each action chunk). If your data has ``absolute`` actions (e.g. target joint angles)
|
| 327 |
+
# you can enable the delta conversion below to convert the actions to delta actions. The only exception
|
| 328 |
+
# is for the gripper actions which are always absolute.
|
| 329 |
+
# In the example below, we would apply the delta conversion to the first 6 actions (joints) and
|
| 330 |
+
# leave the 7th action (gripper) unchanged, i.e. absolute.
|
| 331 |
+
# In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
|
| 332 |
+
# apply a separate delta conversion (that's why it is disabled by default). Choose whether to apply this
|
| 333 |
+
# transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
|
| 334 |
+
|
| 335 |
+
# LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this
|
| 336 |
+
# extra delta transform.
|
| 337 |
+
if self.extra_delta_transform:
|
| 338 |
+
delta_action_mask = _transforms.make_bool_mask(6, -1)
|
| 339 |
+
data_transforms = data_transforms.push(
|
| 340 |
+
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
| 341 |
+
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Model transforms include things like tokenizing the prompt and action targets
|
| 345 |
+
# You do not need to change anything here for your own dataset.
|
| 346 |
+
model_transforms = ModelTransformFactory()(model_config)
|
| 347 |
+
|
| 348 |
+
# We return all data transforms for training and inference. No need to change anything here.
|
| 349 |
+
return dataclasses.replace(
|
| 350 |
+
self.create_base_config(assets_dirs, model_config),
|
| 351 |
+
repack_transforms=repack_transform,
|
| 352 |
+
data_transforms=data_transforms,
|
| 353 |
+
model_transforms=model_transforms,
|
| 354 |
+
)
|
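Annotation (not part of the diff): following the comments above, a repack mapping for your own dataset only needs to rename your dataset's columns to the keys the inference pipeline expects. The sketch below keeps the target keys used above on the left-hand side; the dataset key names on the right-hand side are hypothetical.

```python
# Hypothetical repack mapping for a custom LeRobot dataset whose columns are named differently.
import openpi.transforms as _transforms

repack_transform = _transforms.Group(
    inputs=[
        _transforms.RepackTransform(
            {
                "observation/image": "observation.images.front",        # hypothetical dataset key
                "observation/wrist_image": "observation.images.wrist",  # hypothetical dataset key
                "observation/state": "observation.state",
                "actions": "action",
                "prompt": "prompt",
            }
        )
    ]
)
```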
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@dataclasses.dataclass(frozen=True)
|
| 358 |
+
class RLDSDroidDataConfig(DataConfigFactory):
|
| 359 |
+
"""
|
| 360 |
+
Config for training on DROID, using RLDS data format (for efficient training on larger datasets).
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
rlds_data_dir: str | None = None
|
| 364 |
+
action_space: droid_rlds_dataset.DroidActionSpace | None = None
|
| 365 |
+
|
| 366 |
+
# Filtering options. Can pass a path to a dictionary that maps episodes
|
| 367 |
+
# to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with
|
| 368 |
+
# f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata.
|
| 369 |
+
# Path to the filter dictionary file.
|
| 370 |
+
filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json"
|
| 371 |
+
|
| 372 |
+
@override
|
| 373 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 374 |
+
repack_transform = _transforms.Group(
|
| 375 |
+
inputs=[
|
| 376 |
+
_transforms.RepackTransform(
|
| 377 |
+
{
|
| 378 |
+
"observation/exterior_image_1_left": "observation/image",
|
| 379 |
+
"observation/wrist_image_left": "observation/wrist_image",
|
| 380 |
+
"observation/joint_position": "observation/joint_position",
|
| 381 |
+
"observation/gripper_position": "observation/gripper_position",
|
| 382 |
+
"actions": "actions",
|
| 383 |
+
"prompt": "prompt",
|
| 384 |
+
}
|
| 385 |
+
)
|
| 386 |
+
]
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
data_transforms = _transforms.Group(
|
| 390 |
+
inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
|
| 391 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION:
|
| 395 |
+
# Data loader returns absolute joint position actions -- convert to delta actions for training.
|
| 396 |
+
delta_action_mask = _transforms.make_bool_mask(7, -1)
|
| 397 |
+
data_transforms = data_transforms.push(
|
| 398 |
+
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
| 399 |
+
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
model_transforms = ModelTransformFactory()(model_config)
|
| 403 |
+
|
| 404 |
+
assert self.rlds_data_dir is not None, "Need to set rlds data dir for RLDS data loader."
|
| 405 |
+
|
| 406 |
+
return dataclasses.replace(
|
| 407 |
+
self.create_base_config(assets_dirs, model_config),
|
| 408 |
+
repack_transforms=repack_transform,
|
| 409 |
+
data_transforms=data_transforms,
|
| 410 |
+
model_transforms=model_transforms,
|
| 411 |
+
rlds_data_dir=self.rlds_data_dir,
|
| 412 |
+
action_space=self.action_space,
|
| 413 |
+
filter_dict_path=self.filter_dict_path,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@dataclasses.dataclass(frozen=True)
|
| 418 |
+
class LeRobotDROIDDataConfig(DataConfigFactory):
|
| 419 |
+
"""
|
| 420 |
+
Example data config for custom DROID dataset in LeRobot format.
|
| 421 |
+
To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py
|
| 422 |
+
"""
|
| 423 |
+
|
| 424 |
+
@override
|
| 425 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 426 |
+
repack_transform = _transforms.Group(
|
| 427 |
+
inputs=[
|
| 428 |
+
_transforms.RepackTransform(
|
| 429 |
+
{
|
| 430 |
+
"observation/exterior_image_1_left": "exterior_image_1_left",
|
| 431 |
+
"observation/exterior_image_2_left": "exterior_image_2_left",
|
| 432 |
+
"observation/wrist_image_left": "wrist_image_left",
|
| 433 |
+
"observation/joint_position": "joint_position",
|
| 434 |
+
"observation/gripper_position": "gripper_position",
|
| 435 |
+
"actions": "actions",
|
| 436 |
+
"prompt": "prompt",
|
| 437 |
+
}
|
| 438 |
+
)
|
| 439 |
+
]
|
| 440 |
+
)
|
| 441 |
+
# We assume joint *velocity* actions, so we should *not* apply an additional delta transform.
|
| 442 |
+
data_transforms = _transforms.Group(
|
| 443 |
+
inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
|
| 444 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 445 |
+
)
|
| 446 |
+
model_transforms = ModelTransformFactory()(model_config)
|
| 447 |
+
|
| 448 |
+
return dataclasses.replace(
|
| 449 |
+
self.create_base_config(assets_dirs, model_config),
|
| 450 |
+
repack_transforms=repack_transform,
|
| 451 |
+
data_transforms=data_transforms,
|
| 452 |
+
model_transforms=model_transforms,
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
@dataclasses.dataclass(frozen=True)
|
| 457 |
+
class TrainConfig:
|
| 458 |
+
# Name of the config. Must be unique. Will be used to reference this config.
|
| 459 |
+
name: tyro.conf.Suppress[str]
|
| 460 |
+
# Project name.
|
| 461 |
+
project_name: str = "openpi"
|
| 462 |
+
# Experiment name. Will be used to name the metadata and checkpoint directories.
|
| 463 |
+
exp_name: str = tyro.MISSING
|
| 464 |
+
|
| 465 |
+
# Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models
|
| 466 |
+
# -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may
|
| 467 |
+
# define additional attributes.
|
| 468 |
+
model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config)
|
| 469 |
+
|
| 470 |
+
# A weight loader can optionally load (possibly partial) weights from disk after the model is initialized.
|
| 471 |
+
weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader)
|
| 472 |
+
|
| 473 |
+
# Optional path to a PyTorch checkpoint to load weights from.
|
| 474 |
+
pytorch_weight_path: str | None = None
|
| 475 |
+
|
| 476 |
+
# Spatial Forcing configs
|
| 477 |
+
vggt_weight_path: str | None = None
|
| 478 |
+
vggt_dim: int = 1024
|
| 479 |
+
|
| 480 |
+
vla_layers_align: int | None = None # total 18 for paligemma-2b
|
| 481 |
+
vggt_layers_align: int | None = None # total 24 for VGGT
|
| 482 |
+
|
| 483 |
+
pooling_func: str | None = None
|
| 484 |
+
use_vggt_pe: bool | None = None
|
| 485 |
+
use_vlm_norm: bool | None = None
|
| 486 |
+
|
| 487 |
+
align_loss_coeff: float = 0.0
|
| 488 |
+
|
| 489 |
+
# CapVector configs
|
| 490 |
+
regularization_vector_path: str | None = None
|
| 491 |
+
regularization_coeff: float = 0.0
|
| 492 |
+
|
| 493 |
+
# Precision for PyTorch training.
|
| 494 |
+
pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16"
|
| 495 |
+
|
| 496 |
+
lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule)
|
| 497 |
+
optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW)
|
| 498 |
+
ema_decay: float | None = 0.99
|
| 499 |
+
|
| 500 |
+
# Specifies which weights should be frozen.
|
| 501 |
+
freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing)
|
| 502 |
+
|
| 503 |
+
# Determines the data to be trained on.
|
| 504 |
+
data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig)
|
| 505 |
+
|
| 506 |
+
# Base directory for config assets (e.g., norm stats).
|
| 507 |
+
assets_base_dir: str = "./assets"
|
| 508 |
+
# Base directory for checkpoints.
|
| 509 |
+
checkpoint_base_dir: str = "./checkpoints"
|
| 510 |
+
|
| 511 |
+
# Random seed that will be used by random generators during training.
|
| 512 |
+
seed: int = 42
|
| 513 |
+
# Global batch size.
|
| 514 |
+
batch_size: int = 32
|
| 515 |
+
# Number of workers to use for the data loader. Increasing this number will speed up data loading but
|
| 516 |
+
# will increase memory and CPU usage.
|
| 517 |
+
num_workers: int = 2
|
| 518 |
+
# Number of train steps (batches) to run.
|
| 519 |
+
num_train_steps: int = 30_000
|
| 520 |
+
|
| 521 |
+
# How often (in steps) to log training metrics.
|
| 522 |
+
log_interval: int = 100
|
| 523 |
+
# How often (in steps) to save checkpoints.
|
| 524 |
+
save_interval: int = 1000
|
| 525 |
+
# If set, any existing checkpoints matching step % keep_period == 0 will not be deleted.
|
| 526 |
+
keep_period: int | None = 5000
|
| 527 |
+
|
| 528 |
+
# If true, will overwrite the checkpoint directory if it already exists.
|
| 529 |
+
overwrite: bool = False
|
| 530 |
+
# If true, will resume training from the last checkpoint.
|
| 531 |
+
resume: bool = False
|
| 532 |
+
|
| 533 |
+
# If true, will enable wandb logging.
|
| 534 |
+
wandb_enabled: bool = True
|
| 535 |
+
|
| 536 |
+
# Used to pass metadata to the policy server.
|
| 537 |
+
policy_metadata: dict[str, Any] | None = None
|
| 538 |
+
|
| 539 |
+
# If the value is greater than 1, FSDP will be enabled and the model will be sharded across the specified number of devices; overall
|
| 540 |
+
# device memory will be reduced but training could potentially be slower.
|
| 541 |
+
# e.g., if the total number of devices is 4 and fsdp_devices is 2, then the model will be sharded across 2 devices and run
|
| 542 |
+
# data parallelism between the 2 groups of devices.
|
| 543 |
+
fsdp_devices: int = 1
|
| 544 |
+
|
| 545 |
+
@property
|
| 546 |
+
def assets_dirs(self) -> pathlib.Path:
|
| 547 |
+
"""Get the assets directory for this config."""
|
| 548 |
+
return (pathlib.Path(self.assets_base_dir) / self.name).resolve()
|
| 549 |
+
|
| 550 |
+
@property
|
| 551 |
+
def checkpoint_dir(self) -> pathlib.Path:
|
| 552 |
+
"""Get the checkpoint directory for this config."""
|
| 553 |
+
if not self.exp_name:
|
| 554 |
+
raise ValueError("--exp_name must be set")
|
| 555 |
+
return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve()
|
| 556 |
+
|
| 557 |
+
@property
|
| 558 |
+
def trainable_filter(self) -> nnx.filterlib.Filter:
|
| 559 |
+
"""Get the filter for the trainable parameters."""
|
| 560 |
+
return nnx.All(nnx.Param, nnx.Not(self.freeze_filter))
|
| 561 |
+
|
| 562 |
+
def __post_init__(self) -> None:
|
| 563 |
+
if self.resume and self.overwrite:
|
| 564 |
+
raise ValueError("Cannot resume and overwrite at the same time.")
|
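Annotation (not part of the diff): a quick sketch of how the two directory properties above resolve with the default base directories; the config and experiment names are hypothetical and the snippet only mirrors the pathlib arithmetic in the properties.

```python
import pathlib

assets_base_dir = "./assets"             # TrainConfig.assets_base_dir default
checkpoint_base_dir = "./checkpoints"    # TrainConfig.checkpoint_base_dir default
name, exp_name = "my_config", "run1"     # hypothetical config/experiment names

assets_dirs = (pathlib.Path(assets_base_dir) / name).resolve()
checkpoint_dir = (pathlib.Path(checkpoint_base_dir) / name / exp_name).resolve()
# e.g. <cwd>/assets/my_config and <cwd>/checkpoints/my_config/run1
```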
| 565 |
+
|
| 566 |
+
|
| 567 |
+
# Use `get_config` if you need to get a config by name in your code.
|
| 568 |
+
_CONFIGS = [
|
| 569 |
+
#
|
| 570 |
+
# Inference Aloha configs.
|
| 571 |
+
#
|
| 572 |
+
TrainConfig(
|
| 573 |
+
name="pi0_aloha",
|
| 574 |
+
model=pi0_config.Pi0Config(),
|
| 575 |
+
data=LeRobotAlohaDataConfig(
|
| 576 |
+
assets=AssetsConfig(asset_id="trossen"),
|
| 577 |
+
),
|
| 578 |
+
policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
|
| 579 |
+
),
|
| 580 |
+
TrainConfig(
|
| 581 |
+
name="pi05_aloha",
|
| 582 |
+
model=pi0_config.Pi0Config(pi05=True),
|
| 583 |
+
data=LeRobotAlohaDataConfig(
|
| 584 |
+
assets=AssetsConfig(asset_id="trossen"),
|
| 585 |
+
),
|
| 586 |
+
policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
|
| 587 |
+
),
|
| 588 |
+
TrainConfig(
|
| 589 |
+
name="pi0_aloha_towel",
|
| 590 |
+
model=pi0_config.Pi0Config(),
|
| 591 |
+
data=LeRobotAlohaDataConfig(
|
| 592 |
+
assets=AssetsConfig(asset_id="trossen"),
|
| 593 |
+
default_prompt="fold the towel",
|
| 594 |
+
),
|
| 595 |
+
policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
|
| 596 |
+
),
|
| 597 |
+
TrainConfig(
|
| 598 |
+
name="pi0_aloha_tupperware",
|
| 599 |
+
model=pi0_config.Pi0Config(),
|
| 600 |
+
data=LeRobotAlohaDataConfig(
|
| 601 |
+
assets=AssetsConfig(asset_id="trossen"),
|
| 602 |
+
default_prompt="open the tupperware and put the food on the plate",
|
| 603 |
+
),
|
| 604 |
+
policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
|
| 605 |
+
),
|
| 606 |
+
#
|
| 607 |
+
# Inference DROID configs.
|
| 608 |
+
#
|
| 609 |
+
TrainConfig(
|
| 610 |
+
name="pi0_droid",
|
| 611 |
+
model=pi0_config.Pi0Config(action_horizon=10),
|
| 612 |
+
data=SimpleDataConfig(
|
| 613 |
+
assets=AssetsConfig(asset_id="droid"),
|
| 614 |
+
data_transforms=lambda model: _transforms.Group(
|
| 615 |
+
inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0)],
|
| 616 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 617 |
+
),
|
| 618 |
+
base_config=DataConfig(
|
| 619 |
+
prompt_from_task=True,
|
| 620 |
+
),
|
| 621 |
+
),
|
| 622 |
+
),
|
| 623 |
+
TrainConfig(
|
| 624 |
+
name="pi0_fast_droid",
|
| 625 |
+
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
|
| 626 |
+
data=SimpleDataConfig(
|
| 627 |
+
assets=AssetsConfig(asset_id="droid"),
|
| 628 |
+
data_transforms=lambda model: _transforms.Group(
|
| 629 |
+
inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)],
|
| 630 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 631 |
+
),
|
| 632 |
+
base_config=DataConfig(
|
| 633 |
+
prompt_from_task=True,
|
| 634 |
+
),
|
| 635 |
+
),
|
| 636 |
+
),
|
| 637 |
+
TrainConfig(
|
| 638 |
+
name="pi05_droid",
|
| 639 |
+
model=pi0_config.Pi0Config(action_horizon=15, pi05=True),
|
| 640 |
+
data=SimpleDataConfig(
|
| 641 |
+
assets=AssetsConfig(asset_id="droid"),
|
| 642 |
+
data_transforms=lambda model: _transforms.Group(
|
| 643 |
+
inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)],
|
| 644 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 645 |
+
),
|
| 646 |
+
base_config=DataConfig(
|
| 647 |
+
prompt_from_task=True,
|
| 648 |
+
),
|
| 649 |
+
),
|
| 650 |
+
),
|
| 651 |
+
#
|
| 652 |
+
# Fine-tuning Libero configs.
|
| 653 |
+
#
|
| 654 |
+
# These train configs define the hyperparameters for fine-tuning the base model on your own dataset.
|
| 655 |
+
# They are used to define key elements like the dataset you are training on, the base checkpoint you
|
| 656 |
+
# are using, and other hyperparameters like how many training steps to run or what learning rate to use.
|
| 657 |
+
# For your own dataset, you can copy this class and modify the dataset name, and data transforms based on
|
| 658 |
+
# the comments below.
|
| 659 |
+
TrainConfig(
|
| 660 |
+
# Change the name to reflect your model and dataset.
|
| 661 |
+
name="pi0_libero",
|
| 662 |
+
# Here you define the model config -- In this example we use pi0 as the model
|
| 663 |
+
# architecture and perform *full* finetuning. In the examples below we show how to modify
|
| 664 |
+
# this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture.
|
| 665 |
+
model=pi0_config.Pi0Config(),
|
| 666 |
+
# Here you define the dataset you are training on. In this example we use the Libero
|
| 667 |
+
# dataset. For your own dataset, you can change the repo_id to point to your dataset.
|
| 668 |
+
# Also modify the DataConfig to use the new config you made for your dataset above.
|
| 669 |
+
data=LeRobotLiberoDataConfig(
|
| 670 |
+
repo_id="physical-intelligence/libero",
|
| 671 |
+
base_config=DataConfig(
|
| 672 |
+
# This flag determines whether we load the prompt (i.e. the task instruction) from the
|
| 673 |
+
# ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
|
| 674 |
+
# a field called ``prompt`` in the input dict. The recommended setting is True.
|
| 675 |
+
prompt_from_task=True,
|
| 676 |
+
),
|
| 677 |
+
extra_delta_transform=True,
|
| 678 |
+
),
|
| 679 |
+
# Here you define which pre-trained checkpoint you want to load to initialize the model.
|
| 680 |
+
# This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
|
| 681 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
|
| 682 |
+
# Below you can define other hyperparameters like the learning rate, number of training steps, etc.
|
| 683 |
+
# Check the base TrainConfig class for a full list of available hyperparameters.
|
| 684 |
+
num_train_steps=30_000,
|
| 685 |
+
),
|
| 686 |
+
TrainConfig(
|
| 687 |
+
name="pi0_libero_low_mem_finetune",
|
| 688 |
+
# Here is an example of loading a pi0 model for LoRA fine-tuning.
|
| 689 |
+
model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
|
| 690 |
+
data=LeRobotLiberoDataConfig(
|
| 691 |
+
repo_id="physical-intelligence/libero",
|
| 692 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 693 |
+
extra_delta_transform=True,
|
| 694 |
+
),
|
| 695 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
|
| 696 |
+
num_train_steps=30_000,
|
| 697 |
+
# The freeze filter defines which parameters should be frozen during training.
|
| 698 |
+
# We have a convenience function in the model config that returns the default freeze filter
|
| 699 |
+
# for the given model config for LoRA finetuning. Just make sure it matches the model config
|
| 700 |
+
# you chose above.
|
| 701 |
+
freeze_filter=pi0_config.Pi0Config(
|
| 702 |
+
paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
|
| 703 |
+
).get_freeze_filter(),
|
| 704 |
+
# Turn off EMA for LoRA finetuning.
|
| 705 |
+
ema_decay=None,
|
| 706 |
+
),
|
| 707 |
+
TrainConfig(
|
| 708 |
+
name="pi0_fast_libero",
|
| 709 |
+
# Here is an example of loading a pi0-FAST model for full finetuning.
|
| 710 |
+
# Modify action_dim and action_horizon to match your dataset (action horizon is equal to
|
| 711 |
+
# the desired action chunk length).
|
| 712 |
+
# The max_token_len is the maximum number of (non-image) tokens the model can handle.
|
| 713 |
+
# This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.
|
| 714 |
+
# Choosing this value too small may chop off tokens at the end of your sequence (the code will throw
|
| 715 |
+
# a warning), while choosing it too large will waste memory (since we pad each batch element to the
|
| 716 |
+
# max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for
|
| 717 |
+
# two-arm robots. Generally, err on the lower side here first, and potentially increase the value if
|
| 718 |
+
# you see many warnings being thrown during training.
|
| 719 |
+
model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),
|
| 720 |
+
data=LeRobotLiberoDataConfig(
|
| 721 |
+
repo_id="physical-intelligence/libero",
|
| 722 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 723 |
+
extra_delta_transform=True,
|
| 724 |
+
),
|
| 725 |
+
# Note that we load the pi0-FAST base model checkpoint here.
|
| 726 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
|
| 727 |
+
num_train_steps=30_000,
|
| 728 |
+
),
|
| 729 |
+
TrainConfig(
|
| 730 |
+
name="pi0_fast_libero_low_mem_finetune",
|
| 731 |
+
# Here is an example of loading a pi0-FAST model for LoRA finetuning.
|
| 732 |
+
# For setting action_dim, action_horizon, and max_token_len, see the comments above.
|
| 733 |
+
model=pi0_fast.Pi0FASTConfig(
|
| 734 |
+
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
|
| 735 |
+
),
|
| 736 |
+
data=LeRobotLiberoDataConfig(
|
| 737 |
+
repo_id="physical-intelligence/libero",
|
| 738 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 739 |
+
extra_delta_transform=True,
|
| 740 |
+
),
|
| 741 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
|
| 742 |
+
num_train_steps=30_000,
|
| 743 |
+
# Again, make sure to match the model config above when extracting the freeze filter
|
| 744 |
+
# that specifies which parameters should be frozen during LoRA finetuning.
|
| 745 |
+
freeze_filter=pi0_fast.Pi0FASTConfig(
|
| 746 |
+
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
|
| 747 |
+
).get_freeze_filter(),
|
| 748 |
+
# Turn off EMA for LoRA finetuning.
|
| 749 |
+
ema_decay=None,
|
| 750 |
+
),
|
| 751 |
+
TrainConfig(
|
| 752 |
+
name="pi05_libero",
|
| 753 |
+
model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False),
|
| 754 |
+
data=LeRobotLiberoDataConfig(
|
| 755 |
+
repo_id="physical-intelligence/libero",
|
| 756 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 757 |
+
extra_delta_transform=False,
|
| 758 |
+
),
|
| 759 |
+
batch_size=256,
|
| 760 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 761 |
+
warmup_steps=10_000,
|
| 762 |
+
peak_lr=5e-5,
|
| 763 |
+
decay_steps=1_000_000,
|
| 764 |
+
decay_lr=5e-5,
|
| 765 |
+
),
|
| 766 |
+
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
|
| 767 |
+
ema_decay=0.999,
|
| 768 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 769 |
+
pytorch_weight_path="/path/to/your/pytorch_weight_path",
|
| 770 |
+
num_train_steps=30_000,
|
| 771 |
+
),
|
| 772 |
+
#
|
| 773 |
+
# Fine-tuning Aloha configs.
|
| 774 |
+
#
|
| 775 |
+
# Personal Tasks
|
| 776 |
+
TrainConfig(
|
| 777 |
+
name="pi05_capvector_aloha_place_block", # <config_name>
|
| 778 |
+
model=pi0_config.Pi0Config(pi05=True, discrete_state_input=False),
|
| 779 |
+
data=LeRobotAlohaDataConfig(
|
| 780 |
+
repo_id="cobot_dataset/place_one_floor_block", # your datasets repo_id, like "<org>/<dataset-name>"
|
| 781 |
+
default_prompt="place the green block",
|
| 782 |
+
repack_transforms=_transforms.Group(
|
| 783 |
+
inputs=[
|
| 784 |
+
_transforms.RepackTransform(
|
| 785 |
+
{
|
| 786 |
+
"images": {
|
| 787 |
+
"cam_high": "observation.images.cam_high",
|
| 788 |
+
"cam_left_wrist": "observation.images.cam_left_wrist",
|
| 789 |
+
"cam_right_wrist": "observation.images.cam_right_wrist",
|
| 790 |
+
},
|
| 791 |
+
"state": "observation.state",
|
| 792 |
+
"actions": "action",
|
| 793 |
+
}
|
| 794 |
+
)
|
| 795 |
+
]
|
| 796 |
+
),
|
| 797 |
+
),
|
| 798 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 799 |
+
pytorch_weight_path='./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial',
|
| 800 |
+
# CapVector
|
| 801 |
+
regularization_vector_path='checkpoints/diff/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial.pth',
|
| 802 |
+
regularization_coeff=1e-4,
|
| 803 |
+
#
|
| 804 |
+
num_train_steps=30_000,
|
| 805 |
+
batch_size=16,
|
| 806 |
+
ema_decay=None,
|
| 807 |
+
wandb_enabled=False,
|
| 808 |
+
),
|
| 809 |
+
#
|
| 810 |
+
# This is a test config that is used to illustrate how to train on a custom LeRobot dataset.
|
| 811 |
+
# For instructions on how to convert and train on your own Aloha dataset, see examples/aloha_real/README.md
|
| 812 |
+
TrainConfig(
|
| 813 |
+
name="pi0_aloha_pen_uncap",
|
| 814 |
+
model=pi0_config.Pi0Config(),
|
| 815 |
+
data=LeRobotAlohaDataConfig(
|
| 816 |
+
repo_id="physical-intelligence/aloha_pen_uncap_diverse",
|
| 817 |
+
assets=AssetsConfig(
|
| 818 |
+
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
|
| 819 |
+
asset_id="trossen",
|
| 820 |
+
),
|
| 821 |
+
default_prompt="uncap the pen",
|
| 822 |
+
repack_transforms=_transforms.Group(
|
| 823 |
+
inputs=[
|
| 824 |
+
_transforms.RepackTransform(
|
| 825 |
+
{
|
| 826 |
+
"images": {
|
| 827 |
+
"cam_high": "observation.images.cam_high",
|
| 828 |
+
"cam_left_wrist": "observation.images.cam_left_wrist",
|
| 829 |
+
"cam_right_wrist": "observation.images.cam_right_wrist",
|
| 830 |
+
},
|
| 831 |
+
"state": "observation.state",
|
| 832 |
+
"actions": "action",
|
| 833 |
+
}
|
| 834 |
+
)
|
| 835 |
+
]
|
| 836 |
+
),
|
| 837 |
+
),
|
| 838 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
|
| 839 |
+
num_train_steps=20_000,
|
| 840 |
+
),
|
| 841 |
+
TrainConfig(
|
| 842 |
+
name="pi05_aloha_pen_uncap",
|
| 843 |
+
model=pi0_config.Pi0Config(pi05=True),
|
| 844 |
+
data=LeRobotAlohaDataConfig(
|
| 845 |
+
repo_id="physical-intelligence/aloha_pen_uncap_diverse",
|
| 846 |
+
assets=AssetsConfig(
|
| 847 |
+
assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
|
| 848 |
+
asset_id="trossen",
|
| 849 |
+
),
|
| 850 |
+
default_prompt="uncap the pen",
|
| 851 |
+
repack_transforms=_transforms.Group(
|
| 852 |
+
inputs=[
|
| 853 |
+
_transforms.RepackTransform(
|
| 854 |
+
{
|
| 855 |
+
"images": {
|
| 856 |
+
"cam_high": "observation.images.cam_high",
|
| 857 |
+
"cam_left_wrist": "observation.images.cam_left_wrist",
|
| 858 |
+
"cam_right_wrist": "observation.images.cam_right_wrist",
|
| 859 |
+
},
|
| 860 |
+
"state": "observation.state",
|
| 861 |
+
"actions": "action",
|
| 862 |
+
}
|
| 863 |
+
)
|
| 864 |
+
]
|
| 865 |
+
),
|
| 866 |
+
),
|
| 867 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 868 |
+
num_train_steps=20_000,
|
| 869 |
+
batch_size=64,
|
| 870 |
+
),
|
| 871 |
+
#
|
| 872 |
+
# Fine-tuning DROID configs.
|
| 873 |
+
#
|
| 874 |
+
TrainConfig(
|
| 875 |
+
# This config is for fine-tuning pi0-FAST-base on the *full* DROID dataset.
|
| 876 |
+
# We use RLDS data loading to make training on this large dataset tractable.
|
| 877 |
+
# For fine-tuning on your own DROID dataset, see below.
|
| 878 |
+
name="pi0_fast_full_droid_finetune",
|
| 879 |
+
model=pi0_fast.Pi0FASTConfig(
|
| 880 |
+
action_dim=8,
|
| 881 |
+
action_horizon=16,
|
| 882 |
+
max_token_len=180,
|
| 883 |
+
),
|
| 884 |
+
data=RLDSDroidDataConfig(
|
| 885 |
+
repo_id="droid",
|
| 886 |
+
# Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
|
| 887 |
+
rlds_data_dir="<path_to_droid_rlds_dataset>",
|
| 888 |
+
action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
|
| 889 |
+
),
|
| 890 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
|
| 891 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 892 |
+
warmup_steps=1_000,
|
| 893 |
+
peak_lr=5e-5,
|
| 894 |
+
decay_steps=1_000_000,
|
| 895 |
+
decay_lr=5e-5,
|
| 896 |
+
),
|
| 897 |
+
num_train_steps=100_000, # 100k steps should be sufficient, takes ~2 days on 8x H100s
|
| 898 |
+
batch_size=256,
|
| 899 |
+
log_interval=100,
|
| 900 |
+
save_interval=5000,
|
| 901 |
+
keep_period=20_000,
|
| 902 |
+
num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
|
| 903 |
+
),
|
| 904 |
+
TrainConfig(
|
| 905 |
+
# This config is for fine-tuning pi05 on the *full* DROID dataset.
|
| 906 |
+
# We use RLDS data loading to make training on this large dataset tractable.
|
| 907 |
+
# For fine-tuning on your own DROID dataset, see below.
|
| 908 |
+
name="pi05_full_droid_finetune",
|
| 909 |
+
model=pi0_config.Pi0Config(
|
| 910 |
+
pi05=True,
|
| 911 |
+
action_dim=32,
|
| 912 |
+
action_horizon=16,
|
| 913 |
+
),
|
| 914 |
+
data=RLDSDroidDataConfig(
|
| 915 |
+
repo_id="droid",
|
| 916 |
+
# Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
|
| 917 |
+
rlds_data_dir="/mnt/pi-data/kevin",
|
| 918 |
+
action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
|
| 919 |
+
assets=AssetsConfig(
|
| 920 |
+
assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/",
|
| 921 |
+
asset_id="droid",
|
| 922 |
+
),
|
| 923 |
+
),
|
| 924 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 925 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 926 |
+
warmup_steps=1_000,
|
| 927 |
+
peak_lr=5e-5,
|
| 928 |
+
decay_steps=1_000_000,
|
| 929 |
+
decay_lr=5e-5,
|
| 930 |
+
),
|
| 931 |
+
num_train_steps=100_000,
|
| 932 |
+
batch_size=256,
|
| 933 |
+
log_interval=100,
|
| 934 |
+
save_interval=5000,
|
| 935 |
+
keep_period=10_000,
|
| 936 |
+
num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
|
| 937 |
+
),
|
| 938 |
+
TrainConfig(
|
| 939 |
+
# This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset.
|
| 940 |
+
# Here, we use LeRobot data format (like for all other fine-tuning examples)
|
| 941 |
+
# To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py
|
| 942 |
+
name="pi05_droid_finetune",
|
| 943 |
+
model=pi0_config.Pi0Config(
|
| 944 |
+
pi05=True,
|
| 945 |
+
action_dim=32, # pi05 is trained with 32-dim actions
|
| 946 |
+
action_horizon=16,
|
| 947 |
+
),
|
| 948 |
+
data=LeRobotDROIDDataConfig(
|
| 949 |
+
# Replace with your custom DROID LeRobot dataset repo id.
|
| 950 |
+
repo_id="your_hf_username/my_droid_dataset",
|
| 951 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 952 |
+
assets=AssetsConfig(
|
| 953 |
+
# Important: reuse the original DROID norm stats during fine-tuning!
|
| 954 |
+
assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets",
|
| 955 |
+
asset_id="droid",
|
| 956 |
+
),
|
| 957 |
+
),
|
| 958 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_droid/params"),
|
| 959 |
+
num_train_steps=20_000,
|
| 960 |
+
batch_size=32,
|
| 961 |
+
),
|
| 962 |
+
#
|
| 963 |
+
# ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
|
| 964 |
+
#
|
| 965 |
+
TrainConfig(
|
| 966 |
+
name="pi0_aloha_sim",
|
| 967 |
+
model=pi0_config.Pi0Config(),
|
| 968 |
+
data=LeRobotAlohaDataConfig(
|
| 969 |
+
repo_id="lerobot/aloha_sim_transfer_cube_human",
|
| 970 |
+
default_prompt="Transfer cube",
|
| 971 |
+
use_delta_joint_actions=False,
|
| 972 |
+
),
|
| 973 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
|
| 974 |
+
num_train_steps=20_000,
|
| 975 |
+
),
|
| 976 |
+
#
|
| 977 |
+
# Debugging configs.
|
| 978 |
+
#
|
| 979 |
+
TrainConfig(
|
| 980 |
+
name="debug",
|
| 981 |
+
data=FakeDataConfig(),
|
| 982 |
+
batch_size=2,
|
| 983 |
+
model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"),
|
| 984 |
+
save_interval=100,
|
| 985 |
+
overwrite=True,
|
| 986 |
+
exp_name="debug",
|
| 987 |
+
num_train_steps=10,
|
| 988 |
+
wandb_enabled=False,
|
| 989 |
+
),
|
| 990 |
+
TrainConfig(
|
| 991 |
+
name="debug_restore",
|
| 992 |
+
data=FakeDataConfig(),
|
| 993 |
+
batch_size=2,
|
| 994 |
+
model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"),
|
| 995 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("./checkpoints/debug/debug/9/params"),
|
| 996 |
+
overwrite=True,
|
| 997 |
+
exp_name="debug",
|
| 998 |
+
num_train_steps=10,
|
| 999 |
+
wandb_enabled=False,
|
| 1000 |
+
),
|
| 1001 |
+
TrainConfig(
|
| 1002 |
+
name="debug_pi05",
|
| 1003 |
+
model=pi0_config.Pi0Config(pi05=True, paligemma_variant="dummy", action_expert_variant="dummy"),
|
| 1004 |
+
data=FakeDataConfig(),
|
| 1005 |
+
batch_size=2,
|
| 1006 |
+
num_train_steps=10,
|
| 1007 |
+
overwrite=True,
|
| 1008 |
+
exp_name="debug_pi05",
|
| 1009 |
+
wandb_enabled=False,
|
| 1010 |
+
),
|
| 1011 |
+
#
|
| 1012 |
+
# RoboArena configs.
|
| 1013 |
+
#
|
| 1014 |
+
*roboarena_config.get_roboarena_configs(),
|
| 1015 |
+
]
|
| 1016 |
+
|
| 1017 |
+
if len({config.name for config in _CONFIGS}) != len(_CONFIGS):
|
| 1018 |
+
raise ValueError("Config names must be unique.")
|
| 1019 |
+
_CONFIGS_DICT = {config.name: config for config in _CONFIGS}
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
def cli() -> TrainConfig:
|
| 1023 |
+
return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()})
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
def get_config(config_name: str) -> TrainConfig:
|
| 1027 |
+
"""Get a config by name."""
|
| 1028 |
+
if config_name not in _CONFIGS_DICT:
|
| 1029 |
+
closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0)
|
| 1030 |
+
closest_str = f" Did you mean '{closest[0]}'? " if closest else ""
|
| 1031 |
+
raise ValueError(f"Config '{config_name}' not found.{closest_str}")
|
| 1032 |
+
|
| 1033 |
+
return _CONFIGS_DICT[config_name]
|
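Annotation (not part of the diff): typical usage of the entry points defined above. The config and experiment names below are hypothetical; `get_config`, the frozen `TrainConfig` dataclass, and the `checkpoint_dir` property are all defined in this file.

```python
import dataclasses

from openpi.training import config as _config

# Programmatic use: look up a config by name and override fields (TrainConfig is a frozen dataclass).
cfg = _config.get_config("pi0_libero")
cfg = dataclasses.replace(cfg, exp_name="my_experiment", batch_size=16)
print(cfg.checkpoint_dir)  # ./checkpoints/pi0_libero/my_experiment, resolved to an absolute path
```

From the command line, the same fields are overridden via the tyro CLI built by `cli()`, with the config name as the subcommand (exact flag spelling depends on tyro's name conversion and on which training script calls `cli()`).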
capvector-pi05/src/openpi/training/data_loader.py
ADDED
|
@@ -0,0 +1,540 @@
|
from collections.abc import Iterator, Sequence
import logging
import multiprocessing
import os
import typing
from typing import Literal, Protocol, SupportsIndex, TypeVar

import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import numpy as np
import torch

import openpi.models.model as _model
import openpi.training.config as _config
from openpi.training.droid_rlds_dataset import DroidRldsDataset
import openpi.transforms as _transforms

T_co = TypeVar("T_co", covariant=True)


class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")


class IterableDataset(Protocol[T_co]):
    """Interface for an iterable dataset."""

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")

    def __len__(self) -> int:
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")


class DataLoader(Protocol[T_co]):
    """Interface for a data loader."""

    def data_config(self) -> _config.DataConfig:
        """Get the data config for this data loader."""
        raise NotImplementedError("Subclasses of DataLoader should implement data_config.")

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")


class TransformedDataset(Dataset[T_co]):
    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        return self._transform(self._dataset[index])

    def __len__(self) -> int:
        return len(self._dataset)


class IterableTransformedDataset(IterableDataset[T_co]):
    def __init__(
        self,
        dataset: IterableDataset,
        transforms: Sequence[_transforms.DataTransformFn],
        *,
        is_batched: bool = False,
    ):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)
        self._is_batched = is_batched

    def __iter__(self):
        for sample in self._dataset:
            if self._is_batched:
                # Transforms are designed to be applied to individual samples. So we need to split the batch into
                # individual samples and apply the transform to each sample individually.
                batch_size = next(v.shape[0] for v in sample.values())

                # Split batch into individual samples using tree_map
                individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)]  # noqa: B023

                # Transform each sample
                transformed = [self._transform(s) for s in individual_samples]

                # Recombine batch with tree_map
                yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)
            else:
                yield self._transform(sample)

    def __len__(self) -> int:
        return len(self._dataset)


class FakeDataset(Dataset):
    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._num_samples = num_samples
        self._observation_spec, self._action_spec = model_config.inputs_spec()

    def __getitem__(self, index: SupportsIndex) -> dict:
        rng = jax.random.key(index.__index__())

        def make_from_spec(spec: jax.ShapeDtypeStruct):
            nonlocal rng
            rng, data_rng = jax.random.split(rng)
            # Remove the batch dimension.
            shape = spec.shape[1:]
            if spec.dtype == jnp.float32:
                return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0)
            if spec.dtype == jnp.int32:
                return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048)
            return jnp.zeros(shape=shape, dtype=spec.dtype)

        observation = jax.tree.map(make_from_spec, self._observation_spec)
        action = jax.tree.map(make_from_spec, self._action_spec)

        return {
            **observation.to_dict(),
            "actions": action,
        }

    def __len__(self) -> int:
        return self._num_samples


def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training."""
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset = lerobot_dataset.LeRobotDataset(
        data_config.repo_id,
        delta_timestamps={
            key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
        },
    )

    if data_config.prompt_from_task:
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])

    return dataset


def create_rlds_dataset(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    shuffle: bool = False,
) -> Dataset:
    # At the moment, we only support DROID for RLDS datasets.
    return DroidRldsDataset(
        data_dir=data_config.rlds_data_dir,
        batch_size=batch_size,
        shuffle=shuffle,
        action_chunk_size=action_horizon,
        action_space=data_config.action_space,
        filter_dict_path=data_config.filter_dict_path,
    )


def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Transform the dataset by applying the data transforms."""
    norm_stats = {}
    if data_config.repo_id != "fake" and not skip_norm_stats:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    return TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
    )


def transform_iterable_dataset(
    dataset: IterableDataset,
    data_config: _config.DataConfig,
    *,
    skip_norm_stats: bool = False,
    is_batched: bool = False,
) -> IterableDataset:
    """Transform the dataset by applying the data transforms."""
    norm_stats = {}
    if data_config.repo_id != "fake" and not skip_norm_stats:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    return IterableTransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
        is_batched=is_batched,
    )


def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    logging.info(f"data_config: {data_config}")

    if data_config.rlds_data_dir is not None:
        return create_rlds_data_loader(
            data_config,
            action_horizon=config.model.action_horizon,
            batch_size=config.batch_size,
            sharding=sharding,
            shuffle=shuffle,
            num_batches=num_batches,
            skip_norm_stats=skip_norm_stats,
            framework=framework,
        )
    return create_torch_data_loader(
        data_config,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
    )


def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        action_horizon: The action horizon.
        batch_size: The batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()

    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    return DataLoaderImpl(data_config, data_loader)


def create_rlds_data_loader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create an RLDS data loader for training.

    Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md

    Args:
        data_config: The data configuration.
        action_horizon: The action horizon.
        batch_size: The batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
    """
    if framework == "pytorch":
        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
    dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
    dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)

    data_loader = RLDSDataLoader(
        dataset,
        sharding=sharding,
        num_batches=num_batches,
    )

    return DataLoaderImpl(data_config, data_loader)


class TorchDataLoader:
    """Torch data loader implementation."""

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader.
            shuffle: Whether to shuffle the data.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
        """
        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")

        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        # Store sharding - None for PyTorch, JAX sharding for JAX
        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Use data parallel sharding by default for JAX only.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches

        mp_context = None
        if num_workers > 0:
            mp_context = multiprocessing.get_context("spawn")

        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
        )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        return self._data_loader

    def __iter__(self):
        num_items = 0
        while True:
            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)


def _collate_fn(items):
    """Collate the batch elements into batched numpy arrays."""
    # Make sure to convert to numpy arrays before stacking since some of the incoming elements
    # may be JAX arrays.
    return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)


def _worker_init_fn(worker_id: int) -> None:
    """Tell JAX inside the worker process not to preallocate the GPU memory."""
    # NOTE: This is called after jax is imported inside the worker process. This
    # means that this approach will not work for selecting the backend.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


class RLDSDataLoader:
    """Shallow wrapper around the DROID data loader to make it compatible with openpi.

    All batching already happens in the DROID dataset, so we don't need to do anything here.
    """

    def __init__(
        self,
        dataset: DroidRldsDataset,
        *,
        sharding: jax.sharding.Sharding | None = None,
        num_batches: int | None = None,
    ):
        self._dataset = dataset
        self._num_batches = num_batches

        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")

        if sharding is None:
            # Use data parallel sharding by default.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )

        self._sharding = sharding
        self._num_batches = num_batches

    def __iter__(self):
        num_items = 0
        while True:
            data_iter = iter(self._dataset)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)


class DataLoaderImpl(DataLoader):
    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        return self._data_config

    def __iter__(self):
        for batch in self._data_loader:
            yield _model.Observation.from_dict(batch), batch["actions"]
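
A minimal usage sketch for the loader above (not part of the diff), assuming the registered "debug" config that the tests in data_loader_test.py below also use:

# Sketch only: build a JAX data loader from a named TrainConfig and pull two batches.
# The config name "debug" is borrowed from the test file below; adjust for your setup.
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

train_config = _config.get_config("debug")
loader = _data_loader.create_data_loader(train_config, skip_norm_stats=True, num_batches=2)
for observation, actions in loader:
    # actions has shape (batch_size, action_horizon, action_dim)
    print(type(observation), actions.shape)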
capvector-pi05/src/openpi/training/data_loader_test.py
ADDED
|
@@ -0,0 +1,84 @@
import dataclasses

import jax

from openpi.models import pi0_config
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader


def test_torch_data_loader():
    config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    dataset = _data_loader.FakeDataset(config, 16)

    loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=4,
        num_batches=2,
    )
    batches = list(loader)

    assert len(batches) == 2
    for batch in batches:
        assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch))


def test_torch_data_loader_infinite():
    config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    dataset = _data_loader.FakeDataset(config, 4)

    loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4)
    data_iter = iter(loader)

    for _ in range(10):
        _ = next(data_iter)


def test_torch_data_loader_parallel():
    config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    dataset = _data_loader.FakeDataset(config, 10)

    loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2)
    batches = list(loader)

    assert len(batches) == 2

    for batch in batches:
        assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch))


def test_with_fake_dataset():
    config = _config.get_config("debug")

    loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2)
    batches = list(loader)

    assert len(batches) == 2

    for batch in batches:
        assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch))

    for _, actions in batches:
        assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)


def test_with_real_dataset():
    config = _config.get_config("pi0_aloha_sim")
    config = dataclasses.replace(config, batch_size=4)

    loader = _data_loader.create_data_loader(
        config,
        # Skip since we may not have the data available.
        skip_norm_stats=True,
        num_batches=2,
        shuffle=True,
    )
    # Make sure that we can get the data config.
    assert loader.data_config().repo_id == config.data.repo_id

    batches = list(loader)

    assert len(batches) == 2

    for _, actions in batches:
        assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)
capvector-pi05/src/openpi/training/droid_rlds_dataset.py
ADDED
|
@@ -0,0 +1,221 @@
"""
RLDS-based data loader for DROID.
While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID.
Thus, we provide a data loader example here that uses the RLDS data format.
The data loader also applies a few DROID-specific data filters / transformations.
"""

from enum import Enum
from enum import auto
import json
import logging
from pathlib import Path

import tqdm

import openpi.shared.download as download


class DroidActionSpace(Enum):
    """Action space for DROID dataset."""

    JOINT_POSITION = auto()
    JOINT_VELOCITY = auto()


class DroidRldsDataset:
    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        *,  # Force keyword-only arguments
        shuffle: bool = True,
        action_chunk_size: int = 16,
        # We default to joint position actions, since they allow policy evaluation in simulation.
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        max_loaded_steps_per_episode: int = 100,
        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        filter_dict_path=None,  # Path to json file with indices to sample during training
    ):
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
        tf.config.set_visible_devices([], "GPU")

        builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1")
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)

        # Filter out any unsuccessful trajectories -- we use the file name to check this
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )

        # Repeat dataset so we never run out of data.
        dataset = dataset.repeat()

        # Load the filter dictionary if provided.
        # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
        # (e.g.,
        # {
        #   "<episode key>": [[0, 100], [200, 300]]
        # }
        # means keep frames 0-99 and 200-299).
        if filter_dict_path is not None:
            cached_filter_dict_path = download.maybe_download(filter_dict_path)
            with Path(cached_filter_dict_path).open("r") as f:
                filter_dict = json.load(f)

            logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")

            keys_tensor = []
            values_tensor = []

            for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
                for start, end in ranges:
                    for t in range(start, end):
                        frame_key = f"{episode_key}--{t}"
                        keys_tensor.append(frame_key)
                        values_tensor.append(True)
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
            )
            logging.info("Filter hash table initialized")
        else:
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
            )

        def restructure(traj):
            """Reformat observation and action keys, sample language instruction."""
            # Important: we use joint *position* action space -- easier to simulate!
            actions = tf.concat(
                (
                    (
                        traj["action_dict"]["joint_position"]
                        if action_space == DroidActionSpace.JOINT_POSITION
                        else traj["action_dict"]["joint_velocity"]
                    ),
                    traj["action_dict"]["gripper_position"],
                ),
                axis=-1,
            )
            # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
            # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]
            # Randomly sample one of the three language instructions
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]

            traj_len = tf.shape(traj["action"])[0]
            indices = tf.as_string(tf.range(traj_len))

            # Data filtering:
            # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
            # and each step's time step index. This will index into the filter hash table, and if it returns true,
            # then the frame passes the filter.
            step_id = (
                traj["traj_metadata"]["episode_metadata"]["recording_folderpath"]
                + "--"
                + traj["traj_metadata"]["episode_metadata"]["file_path"]
                + "--"
                + indices
            )
            passes_filter = self.filter_table.lookup(step_id)

            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
                "step_id": step_id,
                "passes_filter": passes_filter,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Splits episode into action chunks."""
            traj_len = tf.shape(traj["actions"])[0]

            # For each step in the trajectory, construct indices for the next n actions
            action_chunk_indices = tf.broadcast_to(
                tf.range(action_chunk_size)[None],
                [traj_len, action_chunk_size],
            ) + tf.broadcast_to(
                tf.range(traj_len)[:, None],
                [traj_len, action_chunk_size],
            )

            # Cap to length of the sequence --> final chunks will repeat the last action
            # This makes sense, since we are using absolute joint + gripper position actions
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)

            # Gather the actions for each chunk
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)

        # Flatten: map from trajectory dataset to dataset of individual action chunks
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        # Filter data that doesn't pass the filter
        def filter_from_dict(frame):
            return frame["passes_filter"]

        dataset = dataset.filter(filter_from_dict)

        # Remove "passes_filter" key from output
        def remove_passes_filter(frame):
            frame.pop("passes_filter")
            return frame

        dataset = dataset.map(remove_passes_filter)

        # Decode images: RLDS saves encoded images, only decode now for efficiency
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # Shuffle, batch
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)
        # Note =>> Seems to reduce memory usage without affecting speed?
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        yield from self.dataset.as_numpy_iterator()

    def __len__(self):
        # This is the approximate number of samples in DROID after filtering.
        # Easier to hardcode than to iterate through the dataset and compute it.
        return 20_000_000
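
A short usage sketch for the class above (not part of the diff); the data_dir value is a placeholder for wherever the "droid" TFDS build lives:

# Sketch only: iterate over pre-batched DROID action chunks.
from openpi.training.droid_rlds_dataset import DroidActionSpace, DroidRldsDataset

ds = DroidRldsDataset(
    data_dir="/path/to/rlds",  # placeholder path to the TFDS "droid" dataset
    batch_size=16,
    shuffle=True,
    action_chunk_size=16,
    action_space=DroidActionSpace.JOINT_POSITION,
)
batch = next(iter(ds))
print(batch["actions"].shape)  # (batch_size, action_chunk_size, 7 joints + gripper)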
capvector-pi05/src/openpi/training/misc/roboarena_config.py
ADDED
|
@@ -0,0 +1,116 @@
"""RoboArena baseline policy configs."""

from typing import TypeAlias

import openpi.models.model as _model
import openpi.models.pi0_config as pi0_config
import openpi.models.pi0_fast as pi0_fast
import openpi.models.tokenizer as _tokenizer
import openpi.policies.droid_policy as droid_policy
import openpi.transforms as _transforms

ModelType: TypeAlias = _model.ModelType


def get_roboarena_configs():
    # Import here to avoid circular imports.
    from openpi.training.config import AssetsConfig
    from openpi.training.config import DataConfig
    from openpi.training.config import SimpleDataConfig
    from openpi.training.config import TrainConfig

    return [
        #
        # RoboArena DROID baseline inference configs.
        #
        TrainConfig(
            # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
            name="paligemma_binning_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                max_token_len=400,
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
            name="paligemma_fast_droid",
            model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
            name="paligemma_fast_specialist_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FSQ tokenizer.
            name="paligemma_vq_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
            name="paligemma_diffusion_droid",
            model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
    ]
capvector-pi05/src/openpi/training/optimizer.py
ADDED
|
@@ -0,0 +1,109 @@
import dataclasses
from typing import Protocol, runtime_checkable

import jax.numpy as jnp
import optax

import openpi.shared.array_typing as at


@runtime_checkable
class LRScheduleConfig(Protocol):
    def create(self) -> optax.Schedule: ...


@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Cosine decay schedule with warmup."""

    warmup_steps: int = 1_000
    peak_lr: float = 2.5e-5
    decay_steps: int = 30_000
    decay_lr: float = 2.5e-6

    def create(self) -> optax.Schedule:
        return optax.warmup_cosine_decay_schedule(
            init_value=self.peak_lr / (self.warmup_steps + 1),
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )


@dataclasses.dataclass(frozen=True)
class RsqrtDecaySchedule(LRScheduleConfig):
    """Inverse square root decay schedule with warmup."""

    warmup_steps: int = 1_000
    peak_lr: float = 5e-5
    timescale: float = 10_000

    def create(self) -> optax.Schedule:
        return optax.join_schedules(
            [
                optax.linear_schedule(
                    init_value=self.peak_lr / (self.warmup_steps + 1),
                    end_value=self.peak_lr,
                    transition_steps=self.warmup_steps,
                ),
                lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale),
            ],
            [self.warmup_steps],
        )


@runtime_checkable
class OptimizerConfig(Protocol):
    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation: ...


@dataclasses.dataclass(frozen=True)
class AdamW(OptimizerConfig):
    """AdamW optimizer."""

    b1: float = 0.9
    b2: float = 0.95
    eps: float = 1e-8
    # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.
    weight_decay: float = 1e-10
    clip_gradient_norm: float = 1.0

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        tx = optax.adamw(
            lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask
        )

        return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx)


@dataclasses.dataclass(frozen=True)
class SGD(OptimizerConfig):
    """SGD optimizer."""

    lr: float = 5e-5
    momentum: float = 0.9
    nesterov: bool = False

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        assert weight_decay_mask is None, "Weight decay is not supported for SGD"
        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)


def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    lr = lr_schedule.create()
    return optimizer.create(lr, weight_decay_mask=weight_decay_mask)
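
A small sketch of how these configs compose (not part of the diff): the schedule builds an optax learning-rate schedule, the optimizer config wraps it in gradient clipping plus AdamW, and the result is a standard optax transformation.

# Sketch only: pair the default AdamW config with a cosine schedule and initialize optax state.
import jax.numpy as jnp

from openpi.training import optimizer as _optimizer

tx = _optimizer.create_optimizer(_optimizer.AdamW(), _optimizer.CosineDecaySchedule())
params = {"w": jnp.zeros((8, 8))}  # toy parameter pytree
opt_state = tx.init(params)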
capvector-pi05/src/openpi/training/sharding.py
ADDED
|
@@ -0,0 +1,102 @@
import contextlib
import logging

import jax
import numpy as np

BATCH_AXIS = "batch"
FSDP_AXIS = "fsdp"
# In FSDP, we shard the data across both the batch and FSDP axes.
DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)


class _MeshState:
    active_mesh: jax.sharding.Mesh | None = None


def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
    if jax.device_count() % num_fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}."
        )
    mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices)
    return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS))


@contextlib.contextmanager
def set_mesh(mesh: jax.sharding.Mesh):
    """Plumbing the mesh deep into the module tree is extremely cumbersome; until the JAX team lands a better API, a
    custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is
    only used in `activation_sharding_constraint` below."""
    if _MeshState.active_mesh is not None:
        raise ValueError("Cannot nest set_mesh context managers.")
    _MeshState.active_mesh = mesh
    try:
        yield
    finally:
        _MeshState.active_mesh = None


def activation_sharding_constraint(pytree):
    if _MeshState.active_mesh is None:
        return pytree
    return jax.lax.with_sharding_constraint(
        pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS))
    )


def fsdp_sharding(
    pytree,
    mesh: jax.sharding.Mesh,
    *,
    min_size_mbytes: int = 4,  # 4 MiB
    log: bool = False,
):
    """Apply FSDP sharding to a pytree of arrays based on the mesh shape.

    Args:
        pytree: The pytree to which the sharding specified by the mesh is applied; note that only array types
            (i.e. objects with a .shape attribute) are considered for sharding.
        mesh: The mesh used for sharding the pytree.
        min_size_mbytes: The minimum size of an array in MiB to be considered for sharding; any array smaller than
            this will be replicated.
        log: If true, will log the sharding decisions for arrays that are being considered for sharding.

    Returns:
        The sharded pytree.
    """
    min_size_bytes = min_size_mbytes * 2**20

    def _shard_arr(kp, array: jax.ShapeDtypeStruct):
        # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging
        if mesh.shape[FSDP_AXIS] == 1:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate scalar and vector arrays
        if not hasattr(array, "shape"):
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        if len(array.shape) < 2:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate small arrays
        if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

        # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension
        axes = np.argsort(array.shape)[::-1]
        spec = [None] * len(axes)
        for i in axes:
            if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
                if log:
                    logging.info(
                        f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}"
                    )
                spec[i] = FSDP_AXIS
                return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))

        # replicate if no valid sharding was found
        if log:
            logging.warning(
                f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}"
            )
        return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    return jax.tree_util.tree_map_with_path(_shard_arr, pytree)
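
A brief sketch of the intended flow (not part of the diff; assumes a machine whose device count is divisible by the chosen FSDP size):

# Sketch only: build a (data, fsdp) mesh, compute per-array shardings, and place the params.
import jax
import jax.numpy as jnp

from openpi.training import sharding

mesh = sharding.make_mesh(num_fsdp_devices=2)  # assumption: device_count % 2 == 0
params = {"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}  # toy pytree
shardings = sharding.fsdp_sharding(params, mesh, log=True)  # large 2D arrays split along one axis, small/1D replicated
params = jax.device_put(params, shardings)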
capvector-pi05/src/openpi/training/utils.py
ADDED
|
@@ -0,0 +1,38 @@
from collections.abc import Callable
from typing import Any

from flax import nnx
from flax import struct
import jax
import optax

from openpi.models import model as _model
from openpi.shared import array_typing as at


@at.typecheck
@struct.dataclass
class TrainState:
    step: at.Int[at.ArrayLike, ""]
    params: nnx.State
    model_def: nnx.GraphDef[_model.BaseModel]
    opt_state: optax.OptState
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    ema_decay: float | None = struct.field(pytree_node=False)
    ema_params: nnx.State | None = None


@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert
    the leaf values to more meaningful strings.
    """
    tree, _ = jax.tree_util.tree_flatten_with_path(tree)
    return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree)


@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Converts a PyTree of arrays into a human-readable string for logging."""
    return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}")
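
A one-liner sketch of the logging helper above (not part of the diff):

# Sketch only: print the shape/dtype of every leaf in a toy parameter pytree.
import jax.numpy as jnp

from openpi.training import utils as training_utils

print(training_utils.array_tree_to_info({"w": jnp.zeros((4, 2)), "b": jnp.zeros((2,))}))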
capvector-pi05/src/openpi/training/weight_loaders.py
ADDED
|
@@ -0,0 +1,104 @@
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from typing import Protocol, runtime_checkable
|
| 5 |
+
|
| 6 |
+
import flax.traverse_util
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import openpi.models.model as _model
|
| 10 |
+
import openpi.shared.array_typing as at
|
| 11 |
+
import openpi.shared.download as download
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@runtime_checkable
|
| 17 |
+
class WeightLoader(Protocol):
|
| 18 |
+
def load(self, params: at.Params) -> at.Params:
|
| 19 |
+
"""Loads the model weights.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
params: Parameters of the model. This is a nested structure of array-like objects that
|
| 23 |
+
represent the model's parameters.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Loaded parameters. The structure must be identical to `params`. If returning a subset of
|
| 27 |
+
the parameters the loader must merge the loaded parameters with `params`.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclasses.dataclass(frozen=True)
|
| 32 |
+
class NoOpWeightLoader(WeightLoader):
|
| 33 |
+
def load(self, params: at.Params) -> at.Params:
|
| 34 |
+
return params
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclasses.dataclass(frozen=True)
|
| 38 |
+
class CheckpointWeightLoader(WeightLoader):
|
| 39 |
+
"""Loads an entire set of weights from a checkpoint.
|
| 40 |
+
|
| 41 |
+
Compatible with:
|
| 42 |
+
trained checkpoints:
|
| 43 |
+
example: "./checkpoints/<config>/<exp>/<step>/params"
|
| 44 |
+
released checkpoints:
|
| 45 |
+
example: "gs://openpi-assets/checkpoints/<model>/params"
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
params_path: str
|
| 49 |
+
|
| 50 |
+
def load(self, params: at.Params) -> at.Params:
|
| 51 |
+
# We are loading np.ndarray and relying on the training code to properly convert and shard the params.
|
| 52 |
+
loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)
|
| 53 |
+
# Add all missing LoRA weights.
|
| 54 |
+
return _merge_params(loaded_params, params, missing_regex=".*lora.*")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclasses.dataclass(frozen=True)
|
| 58 |
+
class PaliGemmaWeightLoader(WeightLoader):
|
| 59 |
+
"""Loads weights from the official PaliGemma checkpoint.
|
| 60 |
+
|
| 61 |
+
This will overwrite existing weights with similar names while keeping all extra weights intact.
|
| 62 |
+
This allows us to support the action expert which is used by the Pi0 model.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def load(self, params: at.Params) -> at.Params:
|
| 66 |
+
path = download.maybe_download(
|
| 67 |
+
"gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz", gs={"token": "anon"}
|
| 68 |
+
)
|
| 69 |
+
with path.open("rb") as f:
|
| 70 |
+
flat_params = dict(np.load(f, allow_pickle=False))
|
| 71 |
+
loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]}
|
| 72 |
+
# Add all missing weights.
|
| 73 |
+
return _merge_params(loaded_params, params, missing_regex=".*")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
|
| 77 |
+
"""Merges the loaded parameters with the reference parameters.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
loaded_params: The parameters to merge.
|
| 81 |
+
params: The reference parameters.
|
| 82 |
+
missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
A new dictionary with the merged parameters.
|
| 86 |
+
"""
|
| 87 |
+
flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
|
| 88 |
+
flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")
|
| 89 |
+
|
| 90 |
+
# First, take all weights that are a subset of the reference weights.
|
| 91 |
+
result = {}
|
| 92 |
+
for k, v in flat_loaded.items():
|
| 93 |
+
if k in flat_ref:
|
| 94 |
+
result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
|
| 95 |
+
|
| 96 |
+
flat_loaded.clear()
|
| 97 |
+
|
| 98 |
+
# Then, merge any missing weights as defined by the missing regex.
|
| 99 |
+
pattern = re.compile(missing_regex)
|
| 100 |
+
for k in {k for k in flat_ref if pattern.fullmatch(k)}:
|
| 101 |
+
if k not in result:
|
| 102 |
+
result[k] = flat_ref[k]
|
| 103 |
+
|
| 104 |
+
return flax.traverse_util.unflatten_dict(result, sep="/")
|
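For orientation, this is how the pieces above are meant to compose (a minimal sketch, not part of this diff; `load_pretrained` and `init_params` are hypothetical names used only for illustration):

def load_pretrained(init_params, params_path=None):
    # init_params: the model's freshly initialized parameter PyTree (hypothetical placeholder).
    # Pick a loader: restore a checkpoint when a path is given, otherwise keep the init as-is.
    loader = CheckpointWeightLoader(params_path) if params_path else NoOpWeightLoader()
    assert isinstance(loader, WeightLoader)  # WeightLoader is a runtime_checkable Protocol
    # The returned tree has the same structure as init_params; keys the checkpoint does not
    # provide (e.g. LoRA weights for CheckpointWeightLoader) are filled back in from
    # init_params via _merge_params.
    return loader.load(init_params)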
capvector-pi05/src/vggt/__init__.py
ADDED
|
File without changes
|
capvector-pi05/src/vggt/dependency/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
from .track_modules.track_refine import refine_track
from .track_modules.blocks import BasicEncoder, ShallowEncoder
from .track_modules.base_track_predictor import BaseTrackerPredictor
capvector-pi05/src/vggt/dependency/distortion.py
ADDED
|
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
from typing import Union

ArrayLike = Union[np.ndarray, torch.Tensor]


def _is_numpy(x: ArrayLike) -> bool:
    return isinstance(x, np.ndarray)


def _is_torch(x: ArrayLike) -> bool:
    return isinstance(x, torch.Tensor)


def _ensure_torch(x: ArrayLike) -> torch.Tensor:
    """Convert input to torch tensor if it's not already one."""
    if _is_numpy(x):
        return torch.from_numpy(x)
    elif _is_torch(x):
        return x
    else:
        return torch.tensor(x)


def single_undistortion(params, tracks_normalized):
    """
    Apply undistortion to the normalized tracks using the given distortion parameters once.

    Args:
        params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
        tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].

    Returns:
        torch.Tensor: Undistorted normalized tracks tensor.
    """
    params = _ensure_torch(params)
    tracks_normalized = _ensure_torch(tracks_normalized)

    u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
    u_undist, v_undist = apply_distortion(params, u, v)
    return torch.stack([u_undist, v_undist], dim=-1)


def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6):
    """
    Iteratively undistort the normalized tracks using the given distortion parameters.

    Args:
        params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
        tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
        max_iterations (int): Maximum number of iterations for the undistortion process.
        max_step_norm (float): Maximum step norm for convergence.
        rel_step_size (float): Relative step size for numerical differentiation.

    Returns:
        torch.Tensor: Undistorted normalized tracks tensor.
    """
    params = _ensure_torch(params)
    tracks_normalized = _ensure_torch(tracks_normalized)

    B, N, _ = tracks_normalized.shape
    u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
    original_u, original_v = u.clone(), v.clone()

    eps = torch.finfo(u.dtype).eps
    for idx in range(max_iterations):
        u_undist, v_undist = apply_distortion(params, u, v)
        dx = original_u - u_undist
        dy = original_v - v_undist

        step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps)
        step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps)

        J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u)
        J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v)
        J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u)
        J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v)

        J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2)

        delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1))

        u += delta[..., 0]
        v += delta[..., 1]

        if torch.max((delta**2).sum(dim=-1)) < max_step_norm:
            break

    return torch.stack([u, v], dim=-1)


def apply_distortion(extra_params, u, v):
    """
    Applies radial or OpenCV distortion to the given 2D points.

    Args:
        extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
        u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
        v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.

    Returns:
        points2D (torch.Tensor): Distorted 2D points of shape BxNx2.
    """
    extra_params = _ensure_torch(extra_params)
    u = _ensure_torch(u)
    v = _ensure_torch(v)

    num_params = extra_params.shape[1]

    if num_params == 1:
        # Simple radial distortion
        k = extra_params[:, 0]
        u2 = u * u
        v2 = v * v
        r2 = u2 + v2
        radial = k[:, None] * r2
        du = u * radial
        dv = v * radial

    elif num_params == 2:
        # RadialCameraModel distortion
        k1, k2 = extra_params[:, 0], extra_params[:, 1]
        u2 = u * u
        v2 = v * v
        r2 = u2 + v2
        radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
        du = u * radial
        dv = v * radial

    elif num_params == 4:
        # OpenCVCameraModel distortion
        k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3])
        u2 = u * u
        v2 = v * v
        uv = u * v
        r2 = u2 + v2
        radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
        du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2)
        dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2)
    else:
        raise ValueError("Unsupported number of distortion parameters")

    u = u.clone() + du
    v = v.clone() + dv

    return u, v


if __name__ == "__main__":
    import random
    import pycolmap

    max_diff = 0
    for i in range(1000):
        # Define distortion parameters (assuming 1 parameter for simplicity)
        B = random.randint(1, 500)
        track_num = random.randint(100, 1000)
        params = torch.rand((B, 1), dtype=torch.float32)  # Batch size 1, 4 parameters
        tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32)  # Batch size 1, 5 points

        # Undistort the tracks
        undistorted_tracks = iterative_undistortion(params, tracks_normalized)

        for b in range(B):
            pycolmap_intri = np.array([1, 0, 0, params[b].item()])
            pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0)

            undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy())
            diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median()
            max_diff = max(max_diff, diff)
            print(f"diff: {diff}, max_diff: {max_diff}")

    import pdb

    pdb.set_trace()
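Since iterative_undistortion is essentially a Newton-style inversion of apply_distortion, a round-trip check makes the relationship concrete (a minimal sketch, not part of this diff; all values are arbitrary placeholders):

import torch

B, N = 2, 8
params = torch.full((B, 1), 0.05)             # one radial coefficient per batch element
pts = torch.rand(B, N, 2) * 0.2               # undistorted normalized coordinates
# Distort the points, then run the iterative solver on the distorted observations.
u_d, v_d = apply_distortion(params, pts[..., 0], pts[..., 1])
recovered = iterative_undistortion(params, torch.stack([u_d, v_d], dim=-1))
# The recovery error should be small (roughly 1e-6 or below for mild distortion).
print(torch.max(torch.abs(recovered - pts)))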
capvector-pi05/src/vggt/dependency/np_to_pycolmap.py
ADDED
|
@@ -0,0 +1,320 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pycolmap
|
| 9 |
+
from .projection import project_3D_points_np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def batch_np_matrix_to_pycolmap(
|
| 13 |
+
points3d,
|
| 14 |
+
extrinsics,
|
| 15 |
+
intrinsics,
|
| 16 |
+
tracks,
|
| 17 |
+
image_size,
|
| 18 |
+
masks=None,
|
| 19 |
+
max_reproj_error=None,
|
| 20 |
+
max_points3D_val=3000,
|
| 21 |
+
shared_camera=False,
|
| 22 |
+
camera_type="SIMPLE_PINHOLE",
|
| 23 |
+
extra_params=None,
|
| 24 |
+
min_inlier_per_frame=64,
|
| 25 |
+
points_rgb=None,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
Convert Batched NumPy Arrays to PyCOLMAP
|
| 29 |
+
|
| 30 |
+
Check https://github.com/colmap/pycolmap for more details about its format
|
| 31 |
+
|
| 32 |
+
NOTE that colmap expects images/cameras/points3D to be 1-indexed
|
| 33 |
+
so there is a +1 offset between colmap index and batch index
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
NOTE: different from VGGSfM, this function:
|
| 37 |
+
1. Use np instead of torch
|
| 38 |
+
2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP)
|
| 39 |
+
"""
|
| 40 |
+
# points3d: Px3
|
| 41 |
+
# extrinsics: Nx3x4
|
| 42 |
+
# intrinsics: Nx3x3
|
| 43 |
+
# tracks: NxPx2
|
| 44 |
+
# masks: NxP
|
| 45 |
+
# image_size: 2, assume all the frames have been padded to the same size
|
| 46 |
+
# where N is the number of frames and P is the number of tracks
|
| 47 |
+
|
| 48 |
+
N, P, _ = tracks.shape
|
| 49 |
+
assert len(extrinsics) == N
|
| 50 |
+
assert len(intrinsics) == N
|
| 51 |
+
assert len(points3d) == P
|
| 52 |
+
assert image_size.shape[0] == 2
|
| 53 |
+
|
| 54 |
+
reproj_mask = None
|
| 55 |
+
|
| 56 |
+
if max_reproj_error is not None:
|
| 57 |
+
projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics)
|
| 58 |
+
projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1)
|
| 59 |
+
projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6
|
| 60 |
+
reproj_mask = projected_diff < max_reproj_error
|
| 61 |
+
|
| 62 |
+
if masks is not None and reproj_mask is not None:
|
| 63 |
+
masks = np.logical_and(masks, reproj_mask)
|
| 64 |
+
elif masks is not None:
|
| 65 |
+
masks = masks
|
| 66 |
+
else:
|
| 67 |
+
masks = reproj_mask
|
| 68 |
+
|
| 69 |
+
assert masks is not None
|
| 70 |
+
|
| 71 |
+
if masks.sum(1).min() < min_inlier_per_frame:
|
| 72 |
+
print(f"Not enough inliers per frame, skip BA.")
|
| 73 |
+
return None, None
|
| 74 |
+
|
| 75 |
+
# Reconstruction object, following the format of PyCOLMAP/COLMAP
|
| 76 |
+
reconstruction = pycolmap.Reconstruction()
|
| 77 |
+
|
| 78 |
+
inlier_num = masks.sum(0)
|
| 79 |
+
valid_mask = inlier_num >= 2 # a track is invalid if without two inliers
|
| 80 |
+
valid_idx = np.nonzero(valid_mask)[0]
|
| 81 |
+
|
| 82 |
+
# Only add 3D points that have sufficient 2D points
|
| 83 |
+
for vidx in valid_idx:
|
| 84 |
+
# Use RGB colors if provided, otherwise use zeros
|
| 85 |
+
rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3)
|
| 86 |
+
reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb)
|
| 87 |
+
|
| 88 |
+
num_points3D = len(valid_idx)
|
| 89 |
+
camera = None
|
| 90 |
+
# frame idx
|
| 91 |
+
for fidx in range(N):
|
| 92 |
+
# set camera
|
| 93 |
+
if camera is None or (not shared_camera):
|
| 94 |
+
pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params)
|
| 95 |
+
|
| 96 |
+
camera = pycolmap.Camera(
|
| 97 |
+
model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# add camera
|
| 101 |
+
reconstruction.add_camera(camera)
|
| 102 |
+
|
| 103 |
+
# set image
|
| 104 |
+
cam_from_world = pycolmap.Rigid3d(
|
| 105 |
+
pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
|
| 106 |
+
) # Rot and Trans
|
| 107 |
+
|
| 108 |
+
image = pycolmap.Image(
|
| 109 |
+
id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
points2D_list = []
|
| 113 |
+
|
| 114 |
+
point2D_idx = 0
|
| 115 |
+
|
| 116 |
+
# NOTE point3D_id start by 1
|
| 117 |
+
for point3D_id in range(1, num_points3D + 1):
|
| 118 |
+
original_track_idx = valid_idx[point3D_id - 1]
|
| 119 |
+
|
| 120 |
+
if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all():
|
| 121 |
+
if masks[fidx][original_track_idx]:
|
| 122 |
+
# It seems we don't need +0.5 for BA
|
| 123 |
+
point2D_xy = tracks[fidx][original_track_idx]
|
| 124 |
+
# Please note when adding the Point2D object
|
| 125 |
+
# It not only requires the 2D xy location, but also the id to 3D point
|
| 126 |
+
points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
|
| 127 |
+
|
| 128 |
+
# add element
|
| 129 |
+
track = reconstruction.points3D[point3D_id].track
|
| 130 |
+
track.add_element(fidx + 1, point2D_idx)
|
| 131 |
+
point2D_idx += 1
|
| 132 |
+
|
| 133 |
+
assert point2D_idx == len(points2D_list)
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
image.points2D = pycolmap.ListPoint2D(points2D_list)
|
| 137 |
+
image.registered = True
|
| 138 |
+
except Exception:
|
| 139 |
+
print(f"frame {fidx + 1} is out of BA")
|
| 140 |
+
image.registered = False
|
| 141 |
+
|
| 142 |
+
# add image
|
| 143 |
+
reconstruction.add_image(image)
|
| 144 |
+
|
| 145 |
+
return reconstruction, valid_mask
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"):
|
| 149 |
+
"""
|
| 150 |
+
Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
|
| 154 |
+
device (str): Ignored in NumPy version (kept for API compatibility).
|
| 155 |
+
camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
num_images = len(reconstruction.images)
|
| 162 |
+
max_points3D_id = max(reconstruction.point3D_ids())
|
| 163 |
+
points3D = np.zeros((max_points3D_id, 3))
|
| 164 |
+
|
| 165 |
+
for point3D_id in reconstruction.points3D:
|
| 166 |
+
points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
|
| 167 |
+
|
| 168 |
+
extrinsics = []
|
| 169 |
+
intrinsics = []
|
| 170 |
+
|
| 171 |
+
extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
|
| 172 |
+
|
| 173 |
+
for i in range(num_images):
|
| 174 |
+
# Extract and append extrinsics
|
| 175 |
+
pyimg = reconstruction.images[i + 1]
|
| 176 |
+
pycam = reconstruction.cameras[pyimg.camera_id]
|
| 177 |
+
matrix = pyimg.cam_from_world.matrix()
|
| 178 |
+
extrinsics.append(matrix)
|
| 179 |
+
|
| 180 |
+
# Extract and append intrinsics
|
| 181 |
+
calibration_matrix = pycam.calibration_matrix()
|
| 182 |
+
intrinsics.append(calibration_matrix)
|
| 183 |
+
|
| 184 |
+
if camera_type == "SIMPLE_RADIAL":
|
| 185 |
+
extra_params.append(pycam.params[-1])
|
| 186 |
+
|
| 187 |
+
# Convert lists to NumPy arrays instead of torch tensors
|
| 188 |
+
extrinsics = np.stack(extrinsics)
|
| 189 |
+
intrinsics = np.stack(intrinsics)
|
| 190 |
+
|
| 191 |
+
if camera_type == "SIMPLE_RADIAL":
|
| 192 |
+
extra_params = np.stack(extra_params)
|
| 193 |
+
extra_params = extra_params[:, None]
|
| 194 |
+
|
| 195 |
+
return points3D, extrinsics, intrinsics, extra_params
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
########################################################
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def batch_np_matrix_to_pycolmap_wo_track(
|
| 202 |
+
points3d,
|
| 203 |
+
points_xyf,
|
| 204 |
+
points_rgb,
|
| 205 |
+
extrinsics,
|
| 206 |
+
intrinsics,
|
| 207 |
+
image_size,
|
| 208 |
+
shared_camera=False,
|
| 209 |
+
camera_type="SIMPLE_PINHOLE",
|
| 210 |
+
):
|
| 211 |
+
"""
|
| 212 |
+
Convert Batched NumPy Arrays to PyCOLMAP
|
| 213 |
+
|
| 214 |
+
Different from batch_np_matrix_to_pycolmap, this function does not use tracks.
|
| 215 |
+
|
| 216 |
+
It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods.
|
| 217 |
+
|
| 218 |
+
Do NOT use this for BA.
|
| 219 |
+
"""
|
| 220 |
+
# points3d: Px3
|
| 221 |
+
# points_xyf: Px3, with x, y coordinates and frame indices
|
| 222 |
+
# points_rgb: Px3, rgb colors
|
| 223 |
+
# extrinsics: Nx3x4
|
| 224 |
+
# intrinsics: Nx3x3
|
| 225 |
+
# image_size: 2, assume all the frames have been padded to the same size
|
| 226 |
+
# where N is the number of frames and P is the number of tracks
|
| 227 |
+
|
| 228 |
+
N = len(extrinsics)
|
| 229 |
+
P = len(points3d)
|
| 230 |
+
|
| 231 |
+
# Reconstruction object, following the format of PyCOLMAP/COLMAP
|
| 232 |
+
reconstruction = pycolmap.Reconstruction()
|
| 233 |
+
|
| 234 |
+
for vidx in range(P):
|
| 235 |
+
reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx])
|
| 236 |
+
|
| 237 |
+
camera = None
|
| 238 |
+
# frame idx
|
| 239 |
+
for fidx in range(N):
|
| 240 |
+
# set camera
|
| 241 |
+
if camera is None or (not shared_camera):
|
| 242 |
+
pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type)
|
| 243 |
+
|
| 244 |
+
camera = pycolmap.Camera(
|
| 245 |
+
model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# add camera
|
| 249 |
+
reconstruction.add_camera(camera)
|
| 250 |
+
|
| 251 |
+
# set image
|
| 252 |
+
cam_from_world = pycolmap.Rigid3d(
|
| 253 |
+
pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
|
| 254 |
+
) # Rot and Trans
|
| 255 |
+
|
| 256 |
+
image = pycolmap.Image(
|
| 257 |
+
id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
points2D_list = []
|
| 261 |
+
|
| 262 |
+
point2D_idx = 0
|
| 263 |
+
|
| 264 |
+
points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx
|
| 265 |
+
points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0]
|
| 266 |
+
|
| 267 |
+
for point3D_batch_idx in points_belong_to_fidx:
|
| 268 |
+
point3D_id = point3D_batch_idx + 1
|
| 269 |
+
point2D_xyf = points_xyf[point3D_batch_idx]
|
| 270 |
+
point2D_xy = point2D_xyf[:2]
|
| 271 |
+
points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
|
| 272 |
+
|
| 273 |
+
# add element
|
| 274 |
+
track = reconstruction.points3D[point3D_id].track
|
| 275 |
+
track.add_element(fidx + 1, point2D_idx)
|
| 276 |
+
point2D_idx += 1
|
| 277 |
+
|
| 278 |
+
assert point2D_idx == len(points2D_list)
|
| 279 |
+
|
| 280 |
+
try:
|
| 281 |
+
image.points2D = pycolmap.ListPoint2D(points2D_list)
|
| 282 |
+
image.registered = True
|
| 283 |
+
except Exception:
|
| 284 |
+
print(f"frame {fidx + 1} does not have any points")
|
| 285 |
+
image.registered = False
|
| 286 |
+
|
| 287 |
+
# add image
|
| 288 |
+
reconstruction.add_image(image)
|
| 289 |
+
|
| 290 |
+
return reconstruction
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None):
|
| 294 |
+
"""
|
| 295 |
+
Helper function to get camera parameters based on camera type.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
fidx: Frame index
|
| 299 |
+
intrinsics: Camera intrinsic parameters
|
| 300 |
+
camera_type: Type of camera model
|
| 301 |
+
extra_params: Additional parameters for certain camera types
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
pycolmap_intri: NumPy array of camera parameters
|
| 305 |
+
"""
|
| 306 |
+
if camera_type == "PINHOLE":
|
| 307 |
+
pycolmap_intri = np.array(
|
| 308 |
+
[intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]
|
| 309 |
+
)
|
| 310 |
+
elif camera_type == "SIMPLE_PINHOLE":
|
| 311 |
+
focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
|
| 312 |
+
pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]])
|
| 313 |
+
elif camera_type == "SIMPLE_RADIAL":
|
| 314 |
+
raise NotImplementedError("SIMPLE_RADIAL is not supported yet")
|
| 315 |
+
focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
|
| 316 |
+
pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]])
|
| 317 |
+
else:
|
| 318 |
+
raise ValueError(f"Camera type {camera_type} is not supported yet")
|
| 319 |
+
|
| 320 |
+
return pycolmap_intri
|
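For orientation, a shape-level sketch of how the converter above is typically driven (illustrative only, not part of this diff; all arrays are random placeholders and pycolmap must be installed):

import numpy as np

N, P = 4, 256                                      # frames, tracks
points3d = np.random.rand(P, 3)
extrinsics = np.tile(np.eye(3, 4), (N, 1, 1))      # per-frame [R|t]
intrinsics = np.tile(np.diag([500.0, 500.0, 1.0]), (N, 1, 1))
tracks = np.random.rand(N, P, 2) * 512             # per-frame 2D observations of each track
masks = np.ones((N, P), dtype=bool)                # which observations are considered inliers
image_size = np.array([512, 512])

# Returns a pycolmap.Reconstruction (cameras, images, points3D) plus the per-track validity mask.
recon, valid_mask = batch_np_matrix_to_pycolmap(
    points3d, extrinsics, intrinsics, tracks, image_size, masks=masks
)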
capvector-pi05/src/vggt/dependency/projection.py
ADDED
|
@@ -0,0 +1,228 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from .distortion import apply_distortion
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def img_from_cam_np(
|
| 13 |
+
intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0
|
| 14 |
+
) -> np.ndarray:
|
| 15 |
+
"""
|
| 16 |
+
Apply intrinsics (and optional radial distortion) to camera-space points.
|
| 17 |
+
|
| 18 |
+
Args
|
| 19 |
+
----
|
| 20 |
+
intrinsics : (B,3,3) camera matrix K.
|
| 21 |
+
points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ.
|
| 22 |
+
extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None.
|
| 23 |
+
default : value used for np.nan replacement.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
points2D : (B,N,2) pixel coordinates.
|
| 28 |
+
"""
|
| 29 |
+
# 1. perspective divide ───────────────────────────────────────
|
| 30 |
+
z = points_cam[:, 2:3, :] # (B,1,N)
|
| 31 |
+
points_cam_norm = points_cam / z # (B,3,N)
|
| 32 |
+
uv = points_cam_norm[:, :2, :] # (B,2,N)
|
| 33 |
+
|
| 34 |
+
# 2. optional distortion ──────────────────────────────────────
|
| 35 |
+
if extra_params is not None:
|
| 36 |
+
uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
|
| 37 |
+
uv = np.stack([uu, vv], axis=1) # (B,2,N)
|
| 38 |
+
|
| 39 |
+
# 3. homogeneous coords then K multiplication ─────────────────
|
| 40 |
+
ones = np.ones_like(uv[:, :1, :]) # (B,1,N)
|
| 41 |
+
points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N)
|
| 42 |
+
|
| 43 |
+
# batched mat-mul: K · [u v 1]ᵀ
|
| 44 |
+
points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N)
|
| 45 |
+
points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N)
|
| 46 |
+
|
| 47 |
+
return points2D.transpose(0, 2, 1) # (B,N,2)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def project_3D_points_np(
|
| 51 |
+
points3D: np.ndarray,
|
| 52 |
+
extrinsics: np.ndarray,
|
| 53 |
+
intrinsics: np.ndarray | None = None,
|
| 54 |
+
extra_params: np.ndarray | None = None,
|
| 55 |
+
*,
|
| 56 |
+
default: float = 0.0,
|
| 57 |
+
only_points_cam: bool = False,
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
NumPy clone of ``project_3D_points``.
|
| 61 |
+
|
| 62 |
+
Parameters
|
| 63 |
+
----------
|
| 64 |
+
points3D : (N,3) world-space points.
|
| 65 |
+
extrinsics : (B,3,4) [R|t] matrix for each of B cameras.
|
| 66 |
+
intrinsics : (B,3,3) K matrix (optional if you only need cam-space).
|
| 67 |
+
extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None.
|
| 68 |
+
default : value used to replace NaNs.
|
| 69 |
+
only_points_cam : if True, skip the projection and return points_cam with points2D as None.
|
| 70 |
+
|
| 71 |
+
Returns
|
| 72 |
+
-------
|
| 73 |
+
(points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True,
|
| 74 |
+
and points_cam is (B,3,N) camera-space coordinates.
|
| 75 |
+
"""
|
| 76 |
+
# ----- 0. prep sizes -----------------------------------------------------
|
| 77 |
+
N = points3D.shape[0] # #points
|
| 78 |
+
B = extrinsics.shape[0] # #cameras
|
| 79 |
+
|
| 80 |
+
# ----- 1. world → homogeneous -------------------------------------------
|
| 81 |
+
w_h = np.ones((N, 1), dtype=points3D.dtype)
|
| 82 |
+
points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4)
|
| 83 |
+
|
| 84 |
+
# broadcast to every camera (no actual copying with np.broadcast_to) ------
|
| 85 |
+
points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4)
|
| 86 |
+
|
| 87 |
+
# ----- 2. apply extrinsics (camera frame) ------------------------------
|
| 88 |
+
# X_cam = E · X_hom
|
| 89 |
+
# einsum: E_(b i j) · X_(b n j) → (b n i)
|
| 90 |
+
points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3)
|
| 91 |
+
points_cam = points_cam.transpose(0, 2, 1) # (B,3,N)
|
| 92 |
+
|
| 93 |
+
if only_points_cam:
|
| 94 |
+
return None, points_cam
|
| 95 |
+
|
| 96 |
+
# ----- 3. intrinsics + distortion ---------------------------------------
|
| 97 |
+
if intrinsics is None:
|
| 98 |
+
raise ValueError("`intrinsics` must be provided unless only_points_cam=True")
|
| 99 |
+
|
| 100 |
+
points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default)
|
| 101 |
+
|
| 102 |
+
return points2D, points_cam
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False):
|
| 106 |
+
"""
|
| 107 |
+
Transforms 3D points to 2D using extrinsic and intrinsic parameters.
|
| 108 |
+
Args:
|
| 109 |
+
points3D (torch.Tensor): 3D points of shape Px3.
|
| 110 |
+
extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
|
| 111 |
+
intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
|
| 112 |
+
extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
|
| 113 |
+
default (float): Default value to replace NaNs.
|
| 114 |
+
only_points_cam (bool): If True, skip the projection and return points2D as None.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True,
|
| 118 |
+
and points_cam is of shape Bx3xN.
|
| 119 |
+
"""
|
| 120 |
+
with torch.cuda.amp.autocast(dtype=torch.double):
|
| 121 |
+
N = points3D.shape[0] # Number of points
|
| 122 |
+
B = extrinsics.shape[0] # Batch size, i.e., number of cameras
|
| 123 |
+
points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4
|
| 124 |
+
# Reshape for batch processing
|
| 125 |
+
points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4
|
| 126 |
+
|
| 127 |
+
# Step 1: Apply extrinsic parameters
|
| 128 |
+
# Transform 3D points to camera coordinate system for all cameras
|
| 129 |
+
points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2))
|
| 130 |
+
|
| 131 |
+
if only_points_cam:
|
| 132 |
+
return None, points_cam
|
| 133 |
+
|
| 134 |
+
# Step 2: Apply intrinsic parameters and (optional) distortion
|
| 135 |
+
points2D = img_from_cam(intrinsics, points_cam, extra_params, default)
|
| 136 |
+
|
| 137 |
+
return points2D, points_cam
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
|
| 141 |
+
"""
|
| 142 |
+
Applies intrinsic parameters and optional distortion to the given 3D points.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
|
| 146 |
+
points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
|
| 147 |
+
extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
|
| 148 |
+
default (float, optional): Default value to replace NaNs in the output.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
# Normalize by the third coordinate (homogeneous division)
|
| 155 |
+
points_cam = points_cam / points_cam[:, 2:3, :]
|
| 156 |
+
# Extract uv
|
| 157 |
+
uv = points_cam[:, :2, :]
|
| 158 |
+
|
| 159 |
+
# Apply distortion if extra_params are provided
|
| 160 |
+
if extra_params is not None:
|
| 161 |
+
uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
|
| 162 |
+
uv = torch.stack([uu, vv], dim=1)
|
| 163 |
+
|
| 164 |
+
# Prepare points_cam for batch matrix multiplication
|
| 165 |
+
points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN
|
| 166 |
+
# Apply intrinsic parameters using batch matrix multiplication
|
| 167 |
+
points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN
|
| 168 |
+
|
| 169 |
+
# Extract x and y coordinates
|
| 170 |
+
points2D = points2D_homo[:, :2, :] # Bx2xN
|
| 171 |
+
|
| 172 |
+
# Replace NaNs with default value
|
| 173 |
+
points2D = torch.nan_to_num(points2D, nan=default)
|
| 174 |
+
|
| 175 |
+
return points2D.transpose(1, 2) # BxNx2
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
# Set up example input
|
| 180 |
+
B, N = 24, 10240
|
| 181 |
+
|
| 182 |
+
for _ in range(100):
|
| 183 |
+
points3D = np.random.rand(N, 3).astype(np.float64)
|
| 184 |
+
extrinsics = np.random.rand(B, 3, 4).astype(np.float64)
|
| 185 |
+
intrinsics = np.random.rand(B, 3, 3).astype(np.float64)
|
| 186 |
+
|
| 187 |
+
# Convert to torch tensors
|
| 188 |
+
points3D_torch = torch.tensor(points3D)
|
| 189 |
+
extrinsics_torch = torch.tensor(extrinsics)
|
| 190 |
+
intrinsics_torch = torch.tensor(intrinsics)
|
| 191 |
+
|
| 192 |
+
# Run NumPy implementation
|
| 193 |
+
points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics)
|
| 194 |
+
|
| 195 |
+
# Run torch implementation
|
| 196 |
+
points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch)
|
| 197 |
+
|
| 198 |
+
# Convert torch output to numpy
|
| 199 |
+
points2D_torch_np = points2D_torch.detach().numpy()
|
| 200 |
+
points_cam_torch_np = points_cam_torch.detach().numpy()
|
| 201 |
+
|
| 202 |
+
# Compute difference
|
| 203 |
+
diff = np.abs(points2D_np - points2D_torch_np)
|
| 204 |
+
print("Difference between NumPy and PyTorch implementations:")
|
| 205 |
+
print(diff)
|
| 206 |
+
|
| 207 |
+
# Check max error
|
| 208 |
+
max_diff = np.max(diff)
|
| 209 |
+
print(f"Maximum difference: {max_diff}")
|
| 210 |
+
|
| 211 |
+
if np.allclose(points2D_np, points2D_torch_np, atol=1e-6):
|
| 212 |
+
print("Implementations match closely.")
|
| 213 |
+
else:
|
| 214 |
+
print("Significant differences detected.")
|
| 215 |
+
|
| 216 |
+
if points_cam_np is not None:
|
| 217 |
+
points_cam_diff = np.abs(points_cam_np - points_cam_torch_np)
|
| 218 |
+
print("Difference between NumPy and PyTorch camera-space coordinates:")
|
| 219 |
+
print(points_cam_diff)
|
| 220 |
+
|
| 221 |
+
# Check max error
|
| 222 |
+
max_cam_diff = np.max(points_cam_diff)
|
| 223 |
+
print(f"Maximum camera-space coordinate difference: {max_cam_diff}")
|
| 224 |
+
|
| 225 |
+
if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6):
|
| 226 |
+
print("Camera-space coordinates match closely.")
|
| 227 |
+
else:
|
| 228 |
+
print("Significant differences detected in camera-space coordinates.")
|
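A tiny numeric check makes the projection convention concrete (illustrative only, not part of this diff): a point on the optical axis of a camera with identity extrinsics should project to the principal point.

import numpy as np

points3D = np.array([[0.0, 0.0, 2.0]])            # one world point, 2 units in front of the camera
extrinsics = np.eye(3, 4)[None]                   # single camera, identity [R|t]
intrinsics = np.array([[[100.0, 0.0, 64.0],
                        [0.0, 100.0, 64.0],
                        [0.0, 0.0, 1.0]]])
pts2d, pts_cam = project_3D_points_np(points3D, extrinsics, intrinsics)
print(pts2d)                                      # approximately [[[64., 64.]]]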
capvector-pi05/src/vggt/dependency/track_modules/__init__.py
ADDED
|
File without changes
|
capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py
ADDED
|
@@ -0,0 +1,190 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
|
| 11 |
+
from .blocks import EfficientUpdateFormer, CorrBlock
|
| 12 |
+
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseTrackerPredictor(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
stride=4,
|
| 19 |
+
corr_levels=5,
|
| 20 |
+
corr_radius=4,
|
| 21 |
+
latent_dim=128,
|
| 22 |
+
hidden_size=384,
|
| 23 |
+
use_spaceatt=True,
|
| 24 |
+
depth=6,
|
| 25 |
+
fine=False,
|
| 26 |
+
):
|
| 27 |
+
super(BaseTrackerPredictor, self).__init__()
|
| 28 |
+
"""
|
| 29 |
+
The base template to create a track predictor
|
| 30 |
+
|
| 31 |
+
Modified from https://github.com/facebookresearch/co-tracker/
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
self.stride = stride
|
| 35 |
+
self.latent_dim = latent_dim
|
| 36 |
+
self.corr_levels = corr_levels
|
| 37 |
+
self.corr_radius = corr_radius
|
| 38 |
+
self.hidden_size = hidden_size
|
| 39 |
+
self.fine = fine
|
| 40 |
+
|
| 41 |
+
self.flows_emb_dim = latent_dim // 2
|
| 42 |
+
self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2
|
| 43 |
+
|
| 44 |
+
if self.fine:
|
| 45 |
+
# TODO this is the old dummy code, will remove this when we train next model
|
| 46 |
+
self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5
|
| 47 |
+
else:
|
| 48 |
+
self.transformer_dim += (4 - self.transformer_dim % 4) % 4
|
| 49 |
+
|
| 50 |
+
space_depth = depth if use_spaceatt else 0
|
| 51 |
+
time_depth = depth
|
| 52 |
+
|
| 53 |
+
self.updateformer = EfficientUpdateFormer(
|
| 54 |
+
space_depth=space_depth,
|
| 55 |
+
time_depth=time_depth,
|
| 56 |
+
input_dim=self.transformer_dim,
|
| 57 |
+
hidden_size=self.hidden_size,
|
| 58 |
+
output_dim=self.latent_dim + 2,
|
| 59 |
+
mlp_ratio=4.0,
|
| 60 |
+
add_space_attn=use_spaceatt,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
self.norm = nn.GroupNorm(1, self.latent_dim)
|
| 64 |
+
|
| 65 |
+
# A linear layer to update track feats at each iteration
|
| 66 |
+
self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
|
| 67 |
+
|
| 68 |
+
if not self.fine:
|
| 69 |
+
self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
| 70 |
+
|
| 71 |
+
def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1):
|
| 72 |
+
"""
|
| 73 |
+
query_points: B x N x 2, the number of batches, tracks, and xy
|
| 74 |
+
fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
|
| 75 |
+
note HH and WW is the size of feature maps instead of original images
|
| 76 |
+
"""
|
| 77 |
+
B, N, D = query_points.shape
|
| 78 |
+
B, S, C, HH, WW = fmaps.shape
|
| 79 |
+
|
| 80 |
+
assert D == 2
|
| 81 |
+
|
| 82 |
+
# Scale the input query_points because we may downsample the images
|
| 83 |
+
# by down_ratio or self.stride
|
| 84 |
+
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
|
| 85 |
+
# its query_points should be query_points/4
|
| 86 |
+
if down_ratio > 1:
|
| 87 |
+
query_points = query_points / float(down_ratio)
|
| 88 |
+
query_points = query_points / float(self.stride)
|
| 89 |
+
|
| 90 |
+
# Init with coords as the query points
|
| 91 |
+
# It means the search will start from the position of query points at the reference frames
|
| 92 |
+
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
|
| 93 |
+
|
| 94 |
+
# Sample/extract the features of the query points in the query frame
|
| 95 |
+
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
|
| 96 |
+
|
| 97 |
+
# init track feats by query feats
|
| 98 |
+
track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
|
| 99 |
+
# back up the init coords
|
| 100 |
+
coords_backup = coords.clone()
|
| 101 |
+
|
| 102 |
+
# Construct the correlation block
|
| 103 |
+
|
| 104 |
+
fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
|
| 105 |
+
|
| 106 |
+
coord_preds = []
|
| 107 |
+
|
| 108 |
+
# Iterative Refinement
|
| 109 |
+
for itr in range(iters):
|
| 110 |
+
# Detach the gradients from the last iteration
|
| 111 |
+
# (in my experience, not very important for performance)
|
| 112 |
+
coords = coords.detach()
|
| 113 |
+
|
| 114 |
+
# Compute the correlation (check the implementation of CorrBlock)
|
| 115 |
+
|
| 116 |
+
fcorr_fn.corr(track_feats)
|
| 117 |
+
fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim
|
| 118 |
+
|
| 119 |
+
corrdim = fcorrs.shape[3]
|
| 120 |
+
|
| 121 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim)
|
| 122 |
+
|
| 123 |
+
# Movement of current coords relative to query points
|
| 124 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
| 125 |
+
|
| 126 |
+
flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
|
| 127 |
+
|
| 128 |
+
# (In my trials, it is also okay to just add the flows_emb instead of concat)
|
| 129 |
+
flows_emb = torch.cat([flows_emb, flows], dim=-1)
|
| 130 |
+
|
| 131 |
+
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
| 132 |
+
|
| 133 |
+
# Concatenate them as the input for the transformers
|
| 134 |
+
transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
|
| 135 |
+
|
| 136 |
+
if transformer_input.shape[2] < self.transformer_dim:
|
| 137 |
+
# pad the features to match the dimension
|
| 138 |
+
pad_dim = self.transformer_dim - transformer_input.shape[2]
|
| 139 |
+
pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
|
| 140 |
+
transformer_input = torch.cat([transformer_input, pad], dim=2)
|
| 141 |
+
|
| 142 |
+
# 2D positional embed
|
| 143 |
+
# TODO: this can be much simplified
|
| 144 |
+
pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
|
| 145 |
+
sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
|
| 146 |
+
sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
|
| 147 |
+
|
| 148 |
+
x = transformer_input + sampled_pos_emb
|
| 149 |
+
|
| 150 |
+
# B, N, S, C
|
| 151 |
+
x = rearrange(x, "(b n) s d -> b n s d", b=B)
|
| 152 |
+
|
| 153 |
+
# Compute the delta coordinates and delta track features
|
| 154 |
+
delta = self.updateformer(x)
|
| 155 |
+
# BN, S, C
|
| 156 |
+
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
|
| 157 |
+
delta_coords_ = delta[:, :, :2]
|
| 158 |
+
delta_feats_ = delta[:, :, 2:]
|
| 159 |
+
|
| 160 |
+
track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
|
| 161 |
+
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
| 162 |
+
|
| 163 |
+
# Update the track features
|
| 164 |
+
track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_
|
| 165 |
+
track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
|
| 166 |
+
|
| 167 |
+
# B x S x N x 2
|
| 168 |
+
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
| 169 |
+
|
| 170 |
+
# Force coord0 as query
|
| 171 |
+
# because we assume the query points should not be changed
|
| 172 |
+
coords[:, 0] = coords_backup[:, 0]
|
| 173 |
+
|
| 174 |
+
# The predicted tracks are in the original image scale
|
| 175 |
+
if down_ratio > 1:
|
| 176 |
+
coord_preds.append(coords * self.stride * down_ratio)
|
| 177 |
+
else:
|
| 178 |
+
coord_preds.append(coords * self.stride)
|
| 179 |
+
|
| 180 |
+
# B, S, N
|
| 181 |
+
if not self.fine:
|
| 182 |
+
vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
| 183 |
+
vis_e = torch.sigmoid(vis_e)
|
| 184 |
+
else:
|
| 185 |
+
vis_e = None
|
| 186 |
+
|
| 187 |
+
if return_feat:
|
| 188 |
+
return coord_preds, vis_e, track_feats, query_track_feat
|
| 189 |
+
else:
|
| 190 |
+
return coord_preds, vis_e
|
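A shape-level smoke test of the predictor's interface, assuming the sibling track_modules (blocks, utils) import cleanly (illustrative only, not part of this diff; the feature maps are random and only the tensor layout matters):

import torch

tracker = BaseTrackerPredictor(stride=4, latent_dim=128, hidden_size=384, depth=6)
B, S, C, HH, WW = 1, 4, 128, 32, 32               # batch, frames, channels, feature map size
fmaps = torch.randn(B, S, C, HH, WW)              # C must equal latent_dim
query_points = torch.rand(B, 16, 2) * (HH * 4)    # pixel coordinates in the original image scale
coord_preds, vis = tracker(query_points, fmaps=fmaps, iters=2)
# coord_preds is a list with one (B, S, 16, 2) tensor per refinement iteration;
# vis has shape (B, S, 16) with per-frame visibility scores.
print(coord_preds[-1].shape, vis.shape)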
capvector-pi05/src/vggt/dependency/track_modules/blocks.py
ADDED
|
@@ -0,0 +1,329 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Modified from https://github.com/facebookresearch/co-tracker/
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Callable
|
| 16 |
+
import collections
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from itertools import repeat
|
| 19 |
+
|
| 20 |
+
from .utils import bilinear_sampler
|
| 21 |
+
|
| 22 |
+
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BasicEncoder(nn.Module):
|
| 26 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
| 27 |
+
super(BasicEncoder, self).__init__()
|
| 28 |
+
|
| 29 |
+
self.stride = stride
|
| 30 |
+
self.norm_fn = "instance"
|
| 31 |
+
self.in_planes = output_dim // 2
|
| 32 |
+
|
| 33 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 34 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 35 |
+
|
| 36 |
+
self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros")
|
| 37 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 38 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
| 39 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
| 40 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
| 41 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
| 42 |
+
|
| 43 |
+
self.conv2 = nn.Conv2d(
|
| 44 |
+
output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros"
|
| 45 |
+
)
|
| 46 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 47 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
| 48 |
+
|
| 49 |
+
for m in self.modules():
|
| 50 |
+
if isinstance(m, nn.Conv2d):
|
| 51 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 52 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
| 53 |
+
if m.weight is not None:
|
| 54 |
+
nn.init.constant_(m.weight, 1)
|
| 55 |
+
if m.bias is not None:
|
| 56 |
+
nn.init.constant_(m.bias, 0)
|
| 57 |
+
|
| 58 |
+
def _make_layer(self, dim, stride=1):
|
| 59 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 60 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
| 61 |
+
layers = (layer1, layer2)
|
| 62 |
+
|
| 63 |
+
self.in_planes = dim
|
| 64 |
+
return nn.Sequential(*layers)
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
_, _, H, W = x.shape
|
| 68 |
+
|
| 69 |
+
x = self.conv1(x)
|
| 70 |
+
x = self.norm1(x)
|
| 71 |
+
x = self.relu1(x)
|
| 72 |
+
|
| 73 |
+
a = self.layer1(x)
|
| 74 |
+
b = self.layer2(a)
|
| 75 |
+
c = self.layer3(b)
|
| 76 |
+
d = self.layer4(c)
|
| 77 |
+
|
| 78 |
+
a = _bilinear_intepolate(a, self.stride, H, W)
|
| 79 |
+
b = _bilinear_intepolate(b, self.stride, H, W)
|
| 80 |
+
c = _bilinear_intepolate(c, self.stride, H, W)
|
| 81 |
+
d = _bilinear_intepolate(d, self.stride, H, W)
|
| 82 |
+
|
| 83 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
| 84 |
+
x = self.norm2(x)
|
| 85 |
+
x = self.relu2(x)
|
| 86 |
+
x = self.conv3(x)
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ShallowEncoder(nn.Module):
|
| 91 |
+
def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
|
| 92 |
+
super(ShallowEncoder, self).__init__()
|
| 93 |
+
self.stride = stride
|
| 94 |
+
self.norm_fn = norm_fn
|
| 95 |
+
self.in_planes = output_dim
|
| 96 |
+
|
| 97 |
+
if self.norm_fn == "group":
|
| 98 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
| 99 |
+
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
| 100 |
+
elif self.norm_fn == "batch":
|
| 101 |
+
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
| 102 |
+
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
| 103 |
+
elif self.norm_fn == "instance":
|
| 104 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 105 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 106 |
+
elif self.norm_fn == "none":
|
| 107 |
+
self.norm1 = nn.Sequential()
|
| 108 |
+
|
| 109 |
+
self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros")
|
| 110 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 111 |
+
|
| 112 |
+
self.layer1 = self._make_layer(output_dim, stride=2)
|
| 113 |
+
|
| 114 |
+
self.layer2 = self._make_layer(output_dim, stride=2)
|
| 115 |
+
self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
|
| 116 |
+
|
| 117 |
+
for m in self.modules():
|
| 118 |
+
if isinstance(m, nn.Conv2d):
|
| 119 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 120 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
| 121 |
+
if m.weight is not None:
|
| 122 |
+
nn.init.constant_(m.weight, 1)
|
| 123 |
+
if m.bias is not None:
|
| 124 |
+
nn.init.constant_(m.bias, 0)
|
| 125 |
+
|
| 126 |
+
def _make_layer(self, dim, stride=1):
|
| 127 |
+
self.in_planes = dim
|
| 128 |
+
|
| 129 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 130 |
+
return layer1
|
| 131 |
+
|
| 132 |
+
def forward(self, x):
|
| 133 |
+
_, _, H, W = x.shape
|
| 134 |
+
|
| 135 |
+
x = self.conv1(x)
|
| 136 |
+
x = self.norm1(x)
|
| 137 |
+
x = self.relu1(x)
|
| 138 |
+
|
| 139 |
+
tmp = self.layer1(x)
|
| 140 |
+
x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
|
| 141 |
+
tmp = self.layer2(tmp)
|
| 142 |
+
x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
|
| 143 |
+
tmp = None
|
| 144 |
+
x = self.conv2(x) + x
|
| 145 |
+
|
| 146 |
+
x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True)
|
| 147 |
+
|
| 148 |
+
return x
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _bilinear_intepolate(x, stride, H, W):
|
| 152 |
+
return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class EfficientUpdateFormer(nn.Module):
|
| 156 |
+
"""
|
| 157 |
+
Transformer model that updates track estimates.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
def __init__(
|
| 161 |
+
self,
|
| 162 |
+
space_depth=6,
|
| 163 |
+
time_depth=6,
|
| 164 |
+
input_dim=320,
|
| 165 |
+
hidden_size=384,
|
| 166 |
+
num_heads=8,
|
| 167 |
+
output_dim=130,
|
| 168 |
+
mlp_ratio=4.0,
|
| 169 |
+
add_space_attn=True,
|
| 170 |
+
num_virtual_tracks=64,
|
| 171 |
+
):
|
| 172 |
+
super().__init__()
|
| 173 |
+
|
| 174 |
+
self.out_channels = 2
|
| 175 |
+
self.num_heads = num_heads
|
| 176 |
+
self.hidden_size = hidden_size
|
| 177 |
+
self.add_space_attn = add_space_attn
|
| 178 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 179 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
| 180 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 181 |
+
|
| 182 |
+
if self.add_space_attn:
|
| 183 |
+
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
|
| 184 |
+
else:
|
| 185 |
+
self.virual_tracks = None
|
| 186 |
+
|
| 187 |
+
self.time_blocks = nn.ModuleList(
|
| 188 |
+
[
|
| 189 |
+
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
|
| 190 |
+
for _ in range(time_depth)
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
if add_space_attn:
|
| 195 |
+
self.space_virtual_blocks = nn.ModuleList(
|
| 196 |
+
[
|
| 197 |
+
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
|
| 198 |
+
for _ in range(space_depth)
|
| 199 |
+
]
|
| 200 |
+
)
|
| 201 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
| 202 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
| 203 |
+
)
|
| 204 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
| 205 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
| 206 |
+
)
|
| 207 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
| 208 |
+
self.initialize_weights()
|
| 209 |
+
|
| 210 |
+
def initialize_weights(self):
|
| 211 |
+
def _basic_init(module):
|
| 212 |
+
if isinstance(module, nn.Linear):
|
| 213 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 214 |
+
if module.bias is not None:
|
| 215 |
+
nn.init.constant_(module.bias, 0)
|
| 216 |
+
|
| 217 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 218 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 219 |
+
if isinstance(module, nn.Linear):
|
| 220 |
+
trunc_normal_(module.weight, std=0.02)
|
| 221 |
+
if module.bias is not None:
|
| 222 |
+
nn.init.zeros_(module.bias)
|
| 223 |
+
|
| 224 |
+
def forward(self, input_tensor, mask=None):
|
| 225 |
+
tokens = self.input_transform(input_tensor)
|
| 226 |
+
|
| 227 |
+
init_tokens = tokens
|
| 228 |
+
|
| 229 |
+
B, _, T, _ = tokens.shape
|
| 230 |
+
|
| 231 |
+
if self.add_space_attn:
|
| 232 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
| 233 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 234 |
+
|
| 235 |
+
_, N, _, _ = tokens.shape
|
| 236 |
+
|
| 237 |
+
j = 0
|
| 238 |
+
for i in range(len(self.time_blocks)):
|
| 239 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 240 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
| 241 |
+
|
| 242 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 243 |
+
if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
|
| 244 |
+
space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
|
| 245 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
| 246 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
| 247 |
+
|
| 248 |
+
virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
|
| 249 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
| 250 |
+
point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
|
| 251 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 252 |
+
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
|
| 253 |
+
j += 1
|
| 254 |
+
|
| 255 |
+
if self.add_space_attn:
|
| 256 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
| 257 |
+
|
| 258 |
+
tokens = tokens + init_tokens
|
| 259 |
+
|
| 260 |
+
flow = self.flow_head(tokens)
|
| 261 |
+
return flow
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
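A minimal usage sketch for the transformer above. It assumes this file is importable as vggt.dependency.track_modules.blocks (matching the paths in this commit); the tensor sizes are illustrative. Input is B x N x T x input_dim token features for N tracks over T frames; the flow head returns B x N x T x output_dim updates.

import torch
from vggt.dependency.track_modules.blocks import EfficientUpdateFormer

model = EfficientUpdateFormer(space_depth=6, time_depth=6, input_dim=320,
                              hidden_size=384, num_heads=8, output_dim=130)
tokens = torch.randn(2, 128, 8, 320)   # B x N x T x input_dim
flow = model(tokens)                   # B x N x T x output_dim
print(flow.shape)                      # torch.Size([2, 128, 8, 130])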
class CorrBlock:
|
| 265 |
+
def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
|
| 266 |
+
B, S, C, H, W = fmaps.shape
|
| 267 |
+
self.S, self.C, self.H, self.W = S, C, H, W
|
| 268 |
+
self.padding_mode = padding_mode
|
| 269 |
+
self.num_levels = num_levels
|
| 270 |
+
self.radius = radius
|
| 271 |
+
self.fmaps_pyramid = []
|
| 272 |
+
self.multiple_track_feats = multiple_track_feats
|
| 273 |
+
|
| 274 |
+
self.fmaps_pyramid.append(fmaps)
|
| 275 |
+
for i in range(self.num_levels - 1):
|
| 276 |
+
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
| 277 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
| 278 |
+
_, _, H, W = fmaps_.shape
|
| 279 |
+
fmaps = fmaps_.reshape(B, S, C, H, W)
|
| 280 |
+
self.fmaps_pyramid.append(fmaps)
|
| 281 |
+
|
| 282 |
+
def sample(self, coords):
|
| 283 |
+
r = self.radius
|
| 284 |
+
B, S, N, D = coords.shape
|
| 285 |
+
assert D == 2
|
| 286 |
+
|
| 287 |
+
H, W = self.H, self.W
|
| 288 |
+
out_pyramid = []
|
| 289 |
+
for i in range(self.num_levels):
|
| 290 |
+
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
| 291 |
+
*_, H, W = corrs.shape
|
| 292 |
+
|
| 293 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
| 294 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
| 295 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
|
| 296 |
+
|
| 297 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
| 298 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
| 299 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 300 |
+
|
| 301 |
+
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode)
|
| 302 |
+
corrs = corrs.view(B, S, N, -1)
|
| 303 |
+
|
| 304 |
+
out_pyramid.append(corrs)
|
| 305 |
+
|
| 306 |
+
out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2
|
| 307 |
+
return out
|
| 308 |
+
|
| 309 |
+
def corr(self, targets):
|
| 310 |
+
B, S, N, C = targets.shape
|
| 311 |
+
if self.multiple_track_feats:
|
| 312 |
+
targets_split = targets.split(C // self.num_levels, dim=-1)
|
| 313 |
+
B, S, N, C = targets_split[0].shape
|
| 314 |
+
|
| 315 |
+
assert C == self.C
|
| 316 |
+
assert S == self.S
|
| 317 |
+
|
| 318 |
+
fmap1 = targets
|
| 319 |
+
|
| 320 |
+
self.corrs_pyramid = []
|
| 321 |
+
for i, fmaps in enumerate(self.fmaps_pyramid):
|
| 322 |
+
*_, H, W = fmaps.shape
|
| 323 |
+
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
| 324 |
+
if self.multiple_track_feats:
|
| 325 |
+
fmap1 = targets_split[i]
|
| 326 |
+
corrs = torch.matmul(fmap1, fmap2s)
|
| 327 |
+
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
| 328 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
| 329 |
+
self.corrs_pyramid.append(corrs)
|
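A minimal usage sketch of the correlation pyramid above (assuming the class is importable from the same module; shapes are illustrative and H = W must survive the average poolings). corr() builds one correlation volume per pyramid level, and sample() gathers a (2r+1) x (2r+1) neighbourhood around each track at every level.

import torch
from vggt.dependency.track_modules.blocks import CorrBlock

fmaps = torch.randn(1, 4, 128, 32, 32)    # B x S x C x H x W feature maps
corr_block = CorrBlock(fmaps, num_levels=3, radius=3)

targets = torch.randn(1, 4, 256, 128)     # B x S x N x C track features
corr_block.corr(targets)                  # builds corrs_pyramid

coords = torch.rand(1, 4, 256, 2) * 31    # B x S x N x 2 track positions (pixels)
corr_feat = corr_block.sample(coords)     # B x S x N x (num_levels * 7 * 7)
print(corr_feat.shape)                    # torch.Size([1, 4, 256, 147])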
capvector-pi05/src/vggt/dependency/track_modules/modules.py
ADDED
@@ -0,0 +1,202 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from functools import partial
|
| 12 |
+
from typing import Callable
|
| 13 |
+
import collections
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from itertools import repeat
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# From PyTorch internals
|
| 19 |
+
def _ntuple(n):
|
| 20 |
+
def parse(x):
|
| 21 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 22 |
+
return tuple(x)
|
| 23 |
+
return tuple(repeat(x, n))
|
| 24 |
+
|
| 25 |
+
return parse
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def exists(val):
|
| 29 |
+
return val is not None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def default(val, d):
|
| 33 |
+
return val if exists(val) else d
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
to_2tuple = _ntuple(2)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
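A quick illustration of the small helpers above (not part of the diff, assuming the definitions are in scope): to_2tuple broadcasts a scalar to a pair and passes real pairs through, which is how Mlp turns a single drop or bias argument into per-layer settings.

print(to_2tuple(0.0))             # (0.0, 0.0)
print(to_2tuple((True, False)))   # (True, False)
print(default(None, 3), default(5, 3))   # 3 5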
class ResidualBlock(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
ResidualBlock: construct a block of two conv layers with residual connections
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
|
| 45 |
+
super(ResidualBlock, self).__init__()
|
| 46 |
+
|
| 47 |
+
self.conv1 = nn.Conv2d(
|
| 48 |
+
in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
|
| 49 |
+
)
|
| 50 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
|
| 51 |
+
self.relu = nn.ReLU(inplace=True)
|
| 52 |
+
|
| 53 |
+
num_groups = planes // 8
|
| 54 |
+
|
| 55 |
+
if norm_fn == "group":
|
| 56 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 57 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 58 |
+
if not stride == 1:
|
| 59 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 60 |
+
|
| 61 |
+
elif norm_fn == "batch":
|
| 62 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
| 63 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
| 64 |
+
if not stride == 1:
|
| 65 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
| 66 |
+
|
| 67 |
+
elif norm_fn == "instance":
|
| 68 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
| 69 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
| 70 |
+
if not stride == 1:
|
| 71 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
| 72 |
+
|
| 73 |
+
elif norm_fn == "none":
|
| 74 |
+
self.norm1 = nn.Sequential()
|
| 75 |
+
self.norm2 = nn.Sequential()
|
| 76 |
+
if not stride == 1:
|
| 77 |
+
self.norm3 = nn.Sequential()
|
| 78 |
+
else:
|
| 79 |
+
raise NotImplementedError
|
| 80 |
+
|
| 81 |
+
if stride == 1:
|
| 82 |
+
self.downsample = None
|
| 83 |
+
else:
|
| 84 |
+
self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
| 85 |
+
|
| 86 |
+
def forward(self, x):
|
| 87 |
+
y = x
|
| 88 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 89 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 90 |
+
|
| 91 |
+
if self.downsample is not None:
|
| 92 |
+
x = self.downsample(x)
|
| 93 |
+
|
| 94 |
+
return self.relu(x + y)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
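A minimal usage sketch of ResidualBlock above (assuming the definition is in scope). The channel count is chosen divisible by 8 so the GroupNorm groups work out; with stride=2 the skip connection goes through the 1x1 downsample convolution.

import torch

block = ResidualBlock(in_planes=64, planes=96, norm_fn="group", stride=2)
x = torch.randn(2, 64, 32, 32)
y = block(x)         # conv path plus 1x1-downsampled skip, then ReLU
print(y.shape)       # torch.Size([2, 96, 16, 16])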
class Mlp(nn.Module):
|
| 98 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 99 |
+
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
in_features,
|
| 103 |
+
hidden_features=None,
|
| 104 |
+
out_features=None,
|
| 105 |
+
act_layer=nn.GELU,
|
| 106 |
+
norm_layer=None,
|
| 107 |
+
bias=True,
|
| 108 |
+
drop=0.0,
|
| 109 |
+
use_conv=False,
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
out_features = out_features or in_features
|
| 113 |
+
hidden_features = hidden_features or in_features
|
| 114 |
+
bias = to_2tuple(bias)
|
| 115 |
+
drop_probs = to_2tuple(drop)
|
| 116 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 117 |
+
|
| 118 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 119 |
+
self.act = act_layer()
|
| 120 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 121 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 122 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 123 |
+
|
| 124 |
+
def forward(self, x):
|
| 125 |
+
x = self.fc1(x)
|
| 126 |
+
x = self.act(x)
|
| 127 |
+
x = self.drop1(x)
|
| 128 |
+
x = self.fc2(x)
|
| 129 |
+
x = self.drop2(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class AttnBlock(nn.Module):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
hidden_size,
|
| 137 |
+
num_heads,
|
| 138 |
+
attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
|
| 139 |
+
mlp_ratio=4.0,
|
| 140 |
+
**block_kwargs,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
Self attention block
|
| 144 |
+
"""
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 147 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 148 |
+
|
| 149 |
+
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
|
| 150 |
+
|
| 151 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 152 |
+
|
| 153 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
| 154 |
+
|
| 155 |
+
def forward(self, x, mask=None):
|
| 156 |
+
# Prepare the mask for PyTorch's attention (it expects a different format)
|
| 157 |
+
# attn_mask = mask if mask is not None else None
|
| 158 |
+
# Normalize before attention
|
| 159 |
+
x = self.norm1(x)
|
| 160 |
+
|
| 161 |
+
# PyTorch's MultiheadAttention returns attn_output, attn_output_weights
|
| 162 |
+
# attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
|
| 163 |
+
|
| 164 |
+
attn_output, _ = self.attn(x, x, x)
|
| 165 |
+
|
| 166 |
+
# Add & Norm
|
| 167 |
+
x = x + attn_output
|
| 168 |
+
x = x + self.mlp(self.norm2(x))
|
| 169 |
+
return x
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
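A minimal usage sketch of the pre-norm self-attention block above (assuming the definitions in this file are in scope; shapes are illustrative). This is how the update transformer applies it over the time axis, with tracks folded into the batch dimension and batch_first=True.

import torch
import torch.nn as nn

blk = AttnBlock(hidden_size=384, num_heads=8, attn_class=nn.MultiheadAttention)
x = torch.randn(16, 24, 384)   # (batch * tracks) x time x channels
y = blk(x)                     # norm -> self-attention -> MLP, both residual
print(y.shape)                 # torch.Size([16, 24, 384])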
class CrossAttnBlock(nn.Module):
|
| 173 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
| 174 |
+
"""
|
| 175 |
+
Cross attention block
|
| 176 |
+
"""
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 179 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
| 180 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 181 |
+
|
| 182 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 183 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 187 |
+
|
| 188 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
| 189 |
+
|
| 190 |
+
def forward(self, x, context, mask=None):
|
| 191 |
+
# Normalize inputs
|
| 192 |
+
x = self.norm1(x)
|
| 193 |
+
context = self.norm_context(context)
|
| 194 |
+
|
| 195 |
+
# Apply cross attention
|
| 196 |
+
# Note: nn.MultiheadAttention returns attn_output, attn_output_weights
|
| 197 |
+
attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
|
| 198 |
+
|
| 199 |
+
# Add & Norm
|
| 200 |
+
x = x + attn_output
|
| 201 |
+
x = x + self.mlp(self.norm2(x))
|
| 202 |
+
return x
|
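A short sketch (assumptions: the definitions above are in scope; sizes illustrative) of how the update transformer in blocks.py pairs these cross-attention blocks: virtual tokens attend to point tokens and point tokens attend back. The transformer keeps separate block lists per direction; a single shared instance is used here only to keep the example small.

import torch

cross = CrossAttnBlock(hidden_size=384, context_dim=384, num_heads=8)
point_tokens = torch.randn(4, 256, 384)    # (batch * time) x points x channels
virtual_tokens = torch.randn(4, 64, 384)   # (batch * time) x virtual tracks x channels

virtual_tokens = cross(virtual_tokens, point_tokens)   # virtual <- point
point_tokens = cross(point_tokens, virtual_tokens)     # point   <- virtual
print(point_tokens.shape, virtual_tokens.shape)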
capvector-pi05/src/vggt/dependency/track_modules/track_refine.py
ADDED
@@ -0,0 +1,419 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from functools import partial
|
| 13 |
+
from torch import nn, einsum
|
| 14 |
+
from einops import rearrange, repeat
|
| 15 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 16 |
+
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import os
|
| 19 |
+
from typing import Union, Tuple
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def refine_track(
|
| 23 |
+
images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Refines the tracking of images using a fine track predictor and a fine feature network.
|
| 27 |
+
Check https://arxiv.org/abs/2312.04563 for more details.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
images (torch.Tensor): The images to be tracked.
|
| 31 |
+
fine_fnet (nn.Module): The fine feature network.
|
| 32 |
+
fine_tracker (nn.Module): The fine track predictor.
|
| 33 |
+
coarse_pred (torch.Tensor): The coarse predictions of tracks.
|
| 34 |
+
compute_score (bool, optional): Whether to compute the score. Defaults to False.
|
| 35 |
+
pradius (int, optional): The radius of a patch. Defaults to 15.
|
| 36 |
+
sradius (int, optional): The search radius. Defaults to 2.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
torch.Tensor: The refined tracks.
|
| 40 |
+
torch.Tensor, optional: The score.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
# coarse_pred shape: BxSxNx2,
|
| 44 |
+
# where B is the batch, S is the video/images length, and N is the number of tracks
|
| 45 |
+
# now we are going to extract patches with the center at coarse_pred
|
| 46 |
+
# Please note that the last dimension indicates x and y, and hence has a dim number of 2
|
| 47 |
+
B, S, N, _ = coarse_pred.shape
|
| 48 |
+
_, _, _, H, W = images.shape
|
| 49 |
+
|
| 50 |
+
# Given the radius of a patch, compute the patch size
|
| 51 |
+
psize = pradius * 2 + 1
|
| 52 |
+
|
| 53 |
+
# Note that we assume the first frame is the query frame
|
| 54 |
+
# so the 2D locations of the first frame are the query points
|
| 55 |
+
query_points = coarse_pred[:, 0]
|
| 56 |
+
|
| 57 |
+
# Given 2D positions, we can use grid_sample to extract patches
|
| 58 |
+
# but it takes too much memory.
|
| 59 |
+
# Instead, we use the floored track xy to sample patches.
|
| 60 |
+
|
| 61 |
+
# For example, if the query point xy is (128.16, 252.78),
|
| 62 |
+
# and the patch size is (31, 31),
|
| 63 |
+
# our goal is to extract the content of a rectangle
|
| 64 |
+
# with left top: (113.16, 237.78)
|
| 65 |
+
# and right bottom: (143.16, 267.78).
|
| 66 |
+
# However, we record the floored left top: (113, 237)
|
| 67 |
+
# and the offset (0.16, 0.78)
|
| 68 |
+
# Then what we need is just unfolding the images like in CNN,
|
| 69 |
+
# picking the content at [(113, 237), (143, 267)].
|
| 70 |
+
# Such operations are highly optimized in PyTorch
|
| 71 |
+
# (well if you really want to use interpolation, check the function extract_glimpse() below)
|
| 72 |
+
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
content_to_extract = images.reshape(B * S, 3, H, W)
|
| 75 |
+
C_in = content_to_extract.shape[1]
|
| 76 |
+
|
| 77 |
+
# Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
|
| 78 |
+
# for the detailed explanation of unfold()
|
| 79 |
+
# Here it runs sliding windows (psize x psize) to build patches
|
| 80 |
+
# The shape changes from
|
| 81 |
+
# (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
|
| 82 |
+
# where Psize is the size of patch
|
| 83 |
+
content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
|
| 84 |
+
|
| 85 |
+
# Floor the coarse predictions to get integers and save the fractional/decimal
|
| 86 |
+
track_int = coarse_pred.floor().int()
|
| 87 |
+
track_frac = coarse_pred - track_int
|
| 88 |
+
|
| 89 |
+
# Note the points represent the center of patches
|
| 90 |
+
# now we get the location of the top left corner of patches
|
| 91 |
+
# because the output of pytorch unfold is indexed by the top left corner
|
| 92 |
+
topleft = track_int - pradius
|
| 93 |
+
topleft_BSN = topleft.clone()
|
| 94 |
+
|
| 95 |
+
# clamp the values so that we will not go out of indexes
|
| 96 |
+
# NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
|
| 97 |
+
# You need to separately clamp x and y if H != W
|
| 98 |
+
topleft = topleft.clamp(0, H - psize)
|
| 99 |
+
|
| 100 |
+
# Reshape from BxSxNx2 -> (B*S)xNx2
|
| 101 |
+
topleft = topleft.reshape(B * S, N, 2)
|
| 102 |
+
|
| 103 |
+
# Prepare batches for indexing, shape: (B*S)xN
|
| 104 |
+
batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
|
| 105 |
+
|
| 106 |
+
# extracted_patches: (B*S) x N x C_in x Psize x Psize
|
| 107 |
+
extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]]
|
| 108 |
+
|
| 109 |
+
if chunk < 0:
|
| 110 |
+
# Extract image patches based on top left corners
|
| 111 |
+
# Feed patches to fine_fnet for features
|
| 112 |
+
patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
|
| 113 |
+
else:
|
| 114 |
+
patches = extracted_patches.reshape(B * S * N, C_in, psize, psize)
|
| 115 |
+
|
| 116 |
+
patch_feat_list = []
|
| 117 |
+
for p in torch.split(patches, chunk):
|
| 118 |
+
patch_feat_list += [fine_fnet(p)]
|
| 119 |
+
patch_feat = torch.cat(patch_feat_list, 0)
|
| 120 |
+
|
| 121 |
+
C_out = patch_feat.shape[1]
|
| 122 |
+
|
| 123 |
+
# Refine the coarse tracks by fine_tracker
|
| 124 |
+
# reshape back to B x S x N x C_out x Psize x Psize
|
| 125 |
+
patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
|
| 126 |
+
patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
|
| 127 |
+
|
| 128 |
+
# Prepare for the query points for fine tracker
|
| 129 |
+
# They are relative to the patch left top corner,
|
| 130 |
+
# instead of the image top left corner now
|
| 131 |
+
# patch_query_points: N x 1 x 2
|
| 132 |
+
# only 1 here because for each patch we only have 1 query point
|
| 133 |
+
patch_query_points = track_frac[:, 0] + pradius
|
| 134 |
+
patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
|
| 135 |
+
|
| 136 |
+
# Feed the PATCH query points and tracks into fine tracker
|
| 137 |
+
fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
|
| 138 |
+
query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# relative the patch top left
|
| 142 |
+
fine_pred_track = fine_pred_track_lists[-1].clone()
|
| 143 |
+
|
| 144 |
+
# From (relative to the patch top left) to (relative to the image top left)
|
| 145 |
+
for idx in range(len(fine_pred_track_lists)):
|
| 146 |
+
fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N)
|
| 147 |
+
fine_level = fine_level.squeeze(-2)
|
| 148 |
+
fine_level = fine_level + topleft_BSN
|
| 149 |
+
fine_pred_track_lists[idx] = fine_level
|
| 150 |
+
|
| 151 |
+
# relative to the image top left
|
| 152 |
+
refined_tracks = fine_pred_track_lists[-1].clone()
|
| 153 |
+
refined_tracks[:, 0] = query_points
|
| 154 |
+
|
| 155 |
+
score = None
|
| 156 |
+
|
| 157 |
+
if compute_score:
|
| 158 |
+
score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out)
|
| 159 |
+
|
| 160 |
+
return refined_tracks, score
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
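A standalone sketch (not from the diff) of the floor-and-unfold patch trick the comments in refine_track describe: record the integer top-left corner plus the fractional offset, then index the unfolded tensor instead of grid-sampling at sub-pixel locations. Sizes are illustrative assumptions.

import torch

H = W = 64
pradius = 3
psize = 2 * pradius + 1
img = torch.randn(1, 3, H, W)

point = torch.tensor([[20.16, 33.78]])                    # (x, y), sub-pixel
top_left = (point.floor().int() - pradius).clamp(0, H - psize)
offset = point - point.floor()                            # kept for the fine tracker

# unfold: 1 x 3 x H' x W' x psize x psize, indexed by the patch top-left corner
patches = img.unfold(2, psize, 1).unfold(3, psize, 1)
patch = patches[0, :, top_left[0, 1], top_left[0, 0]]     # y indexes dim 2, x dim 3
print(patch.shape, offset)                                # torch.Size([3, 7, 7]) ...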
def refine_track_v0(
|
| 164 |
+
images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6
|
| 165 |
+
):
|
| 166 |
+
"""
|
| 167 |
+
COPIED FROM VGGSfM
|
| 168 |
+
|
| 169 |
+
Refines the tracking of images using a fine track predictor and a fine feature network.
|
| 170 |
+
Check https://arxiv.org/abs/2312.04563 for more details.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
images (torch.Tensor): The images to be tracked.
|
| 174 |
+
fine_fnet (nn.Module): The fine feature network.
|
| 175 |
+
fine_tracker (nn.Module): The fine track predictor.
|
| 176 |
+
coarse_pred (torch.Tensor): The coarse predictions of tracks.
|
| 177 |
+
compute_score (bool, optional): Whether to compute the score. Defaults to False.
|
| 178 |
+
pradius (int, optional): The radius of a patch. Defaults to 15.
|
| 179 |
+
sradius (int, optional): The search radius. Defaults to 2.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
torch.Tensor: The refined tracks.
|
| 183 |
+
torch.Tensor, optional: The score.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
# coarse_pred shape: BxSxNx2,
|
| 187 |
+
# where B is the batch, S is the video/images length, and N is the number of tracks
|
| 188 |
+
# now we are going to extract patches with the center at coarse_pred
|
| 189 |
+
# Please note that the last dimension indicates x and y, and hence has a dim number of 2
|
| 190 |
+
B, S, N, _ = coarse_pred.shape
|
| 191 |
+
_, _, _, H, W = images.shape
|
| 192 |
+
|
| 193 |
+
# Given the radius of a patch, compute the patch size
|
| 194 |
+
psize = pradius * 2 + 1
|
| 195 |
+
|
| 196 |
+
# Note that we assume the first frame is the query frame
|
| 197 |
+
# so the 2D locations of the first frame are the query points
|
| 198 |
+
query_points = coarse_pred[:, 0]
|
| 199 |
+
|
| 200 |
+
# Given 2D positions, we can use grid_sample to extract patches
|
| 201 |
+
# but it takes too much memory.
|
| 202 |
+
# Instead, we use the floored track xy to sample patches.
|
| 203 |
+
|
| 204 |
+
# For example, if the query point xy is (128.16, 252.78),
|
| 205 |
+
# and the patch size is (31, 31),
|
| 206 |
+
# our goal is to extract the content of a rectangle
|
| 207 |
+
# with left top: (113.16, 237.78)
|
| 208 |
+
# and right bottom: (143.16, 267.78).
|
| 209 |
+
# However, we record the floored left top: (113, 237)
|
| 210 |
+
# and the offset (0.16, 0.78)
|
| 211 |
+
# Then what we need is just unfolding the images like in CNN,
|
| 212 |
+
# picking the content at [(113, 237), (143, 267)].
|
| 213 |
+
# Such operations are highly optimized in PyTorch
|
| 214 |
+
# (well if you really want to use interpolation, check the function extract_glimpse() below)
|
| 215 |
+
|
| 216 |
+
with torch.no_grad():
|
| 217 |
+
content_to_extract = images.reshape(B * S, 3, H, W)
|
| 218 |
+
C_in = content_to_extract.shape[1]
|
| 219 |
+
|
| 220 |
+
# Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
|
| 221 |
+
# for the detailed explanation of unfold()
|
| 222 |
+
# Here it runs sliding windows (psize x psize) to build patches
|
| 223 |
+
# The shape changes from
|
| 224 |
+
# (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
|
| 225 |
+
# where Psize is the size of patch
|
| 226 |
+
content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
|
| 227 |
+
|
| 228 |
+
# Floor the coarse predictions to get integers and save the fractional/decimal
|
| 229 |
+
track_int = coarse_pred.floor().int()
|
| 230 |
+
track_frac = coarse_pred - track_int
|
| 231 |
+
|
| 232 |
+
# Note the points represent the center of patches
|
| 233 |
+
# now we get the location of the top left corner of patches
|
| 234 |
+
# because the output of pytorch unfold is indexed by the top left corner
|
| 235 |
+
topleft = track_int - pradius
|
| 236 |
+
topleft_BSN = topleft.clone()
|
| 237 |
+
|
| 238 |
+
# clamp the values so that we will not go out of indexes
|
| 239 |
+
# NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
|
| 240 |
+
# You need to separately clamp x and y if H != W
|
| 241 |
+
topleft = topleft.clamp(0, H - psize)
|
| 242 |
+
|
| 243 |
+
# Reshape from BxSxNx2 -> (B*S)xNx2
|
| 244 |
+
topleft = topleft.reshape(B * S, N, 2)
|
| 245 |
+
|
| 246 |
+
# Prepare batches for indexing, shape: (B*S)xN
|
| 247 |
+
batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
|
| 248 |
+
|
| 249 |
+
# Extract image patches based on top left corners
|
| 250 |
+
# extracted_patches: (B*S) x N x C_in x Psize x Psize
|
| 251 |
+
extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]]
|
| 252 |
+
|
| 253 |
+
# Feed patches to fine_fnet for features
|
| 254 |
+
patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
|
| 255 |
+
|
| 256 |
+
C_out = patch_feat.shape[1]
|
| 257 |
+
|
| 258 |
+
# Refine the coarse tracks by fine_tracker
|
| 259 |
+
|
| 260 |
+
# reshape back to B x S x N x C_out x Psize x Psize
|
| 261 |
+
patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
|
| 262 |
+
patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
|
| 263 |
+
|
| 264 |
+
# Prepare for the query points for fine tracker
|
| 265 |
+
# They are relative to the patch left top corner,
|
| 266 |
+
# instead of the image top left corner now
|
| 267 |
+
# patch_query_points: N x 1 x 2
|
| 268 |
+
# only 1 here because for each patch we only have 1 query point
|
| 269 |
+
patch_query_points = track_frac[:, 0] + pradius
|
| 270 |
+
patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
|
| 271 |
+
|
| 272 |
+
# Feed the PATCH query points and tracks into fine tracker
|
| 273 |
+
fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
|
| 274 |
+
query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# relative the patch top left
|
| 278 |
+
fine_pred_track = fine_pred_track_lists[-1].clone()
|
| 279 |
+
|
| 280 |
+
# From (relative to the patch top left) to (relative to the image top left)
|
| 281 |
+
for idx in range(len(fine_pred_track_lists)):
|
| 282 |
+
fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N)
|
| 283 |
+
fine_level = fine_level.squeeze(-2)
|
| 284 |
+
fine_level = fine_level + topleft_BSN
|
| 285 |
+
fine_pred_track_lists[idx] = fine_level
|
| 286 |
+
|
| 287 |
+
# relative to the image top left
|
| 288 |
+
refined_tracks = fine_pred_track_lists[-1].clone()
|
| 289 |
+
refined_tracks[:, 0] = query_points
|
| 290 |
+
|
| 291 |
+
score = None
|
| 292 |
+
|
| 293 |
+
if compute_score:
|
| 294 |
+
score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out)
|
| 295 |
+
|
| 296 |
+
return refined_tracks, score
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
################################## NOTE: NOT USED ##################################
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out):
|
| 303 |
+
"""
|
| 304 |
+
Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps,
|
| 305 |
+
given the query point features and reference frame feature maps
|
| 306 |
+
"""
|
| 307 |
+
|
| 308 |
+
from kornia.utils.grid import create_meshgrid
|
| 309 |
+
from kornia.geometry.subpix import dsnt
|
| 310 |
+
|
| 311 |
+
# query_point_feat initial shape: B x N x C_out,
|
| 312 |
+
# query_point_feat indicates the feat at the coorponsing query points
|
| 313 |
+
# Therefore we don't have S dimension here
|
| 314 |
+
query_point_feat = query_point_feat.reshape(B, N, C_out)
|
| 315 |
+
# reshape and expand to B x (S-1) x N x C_out
|
| 316 |
+
query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1)
|
| 317 |
+
# and reshape to (B*(S-1)*N) x C_out
|
| 318 |
+
query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out)
|
| 319 |
+
|
| 320 |
+
# Radius and size for computing the score
|
| 321 |
+
ssize = sradius * 2 + 1
|
| 322 |
+
|
| 323 |
+
# Reshape, you know it, so many reshaping operations
|
| 324 |
+
patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N)
|
| 325 |
+
|
| 326 |
+
# Again, we unfold the patches to smaller patches
|
| 327 |
+
# so that we can then focus on smaller patches
|
| 328 |
+
# patch_feat_unfold shape:
|
| 329 |
+
# B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize
|
| 330 |
+
# well a bit scary, but actually not
|
| 331 |
+
patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1)
|
| 332 |
+
|
| 333 |
+
# Do the same stuffs above, i.e., the same as extracting patches
|
| 334 |
+
fine_prediction_floor = fine_pred_track.floor().int()
|
| 335 |
+
fine_level_floor_topleft = fine_prediction_floor - sradius
|
| 336 |
+
|
| 337 |
+
# Clamp to ensure the smaller patch is valid
|
| 338 |
+
fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize)
|
| 339 |
+
fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2)
|
| 340 |
+
|
| 341 |
+
# Prepare the batch indices and xy locations
|
| 342 |
+
|
| 343 |
+
batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN
|
| 344 |
+
batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N
|
| 345 |
+
y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices
|
| 346 |
+
x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices
|
| 347 |
+
|
| 348 |
+
reference_frame_feat = patch_feat_unfold.reshape(
|
| 349 |
+
B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Note again, according to pytorch convention
|
| 353 |
+
# x_indices corresponds to [..., 1] and y_indices corresponds to [..., 0]
|
| 354 |
+
reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices]
|
| 355 |
+
reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize)
|
| 356 |
+
# pick the frames other than the first one, so we have S-1 frames here
|
| 357 |
+
reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize)
|
| 358 |
+
|
| 359 |
+
# Compute similarity
|
| 360 |
+
sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat)
|
| 361 |
+
softmax_temp = 1.0 / C_out**0.5
|
| 362 |
+
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
|
| 363 |
+
# 2D heatmaps
|
| 364 |
+
heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize
|
| 365 |
+
|
| 366 |
+
coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
|
| 367 |
+
grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape(
|
| 368 |
+
1, -1, 2
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2
|
| 372 |
+
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability
|
| 373 |
+
|
| 374 |
+
score = std.reshape(B, S - 1, N)
|
| 375 |
+
# set score as 1 for the query frame
|
| 376 |
+
score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1)
|
| 377 |
+
|
| 378 |
+
return score
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
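A pure-torch sketch (not from the diff) of the uncertainty score idea in compute_score_fn above: turn a similarity map into a softmax heatmap and use its spread (summed standard deviation) as the score. The real code computes the expectation with kornia's dsnt helpers; the temperature 1 / sqrt(C) and the normalized grid mirror it, other values are illustrative.

import torch

ssize = 5
sim = torch.randn(ssize * ssize)
heatmap = torch.softmax(sim / (128 ** 0.5), dim=0).reshape(ssize, ssize)

ys, xs = torch.meshgrid(torch.linspace(-1, 1, ssize),
                        torch.linspace(-1, 1, ssize), indexing="ij")
grid = torch.stack([xs, ys], dim=-1).reshape(-1, 2)
p = heatmap.reshape(-1, 1)
mean = (grid * p).sum(0)                           # expected (x, y)
var = (grid ** 2 * p).sum(0) - mean ** 2
score = torch.sqrt(var.clamp(min=1e-10)).sum()     # larger = more spread out
print(float(score))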
def extract_glimpse(
|
| 382 |
+
tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None
|
| 383 |
+
):
|
| 384 |
+
B, C, W, H = tensor.shape
|
| 385 |
+
|
| 386 |
+
h, w = size
|
| 387 |
+
xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0
|
| 388 |
+
ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0
|
| 389 |
+
|
| 390 |
+
vy, vx = torch.meshgrid(ys, xs)
|
| 391 |
+
grid = torch.stack([vx, vy], dim=-1) # h, w, 2
|
| 392 |
+
grid = grid[None]
|
| 393 |
+
|
| 394 |
+
B, N, _ = offsets.shape
|
| 395 |
+
|
| 396 |
+
offsets = offsets.reshape((B * N), 1, 1, 2)
|
| 397 |
+
offsets_grid = offsets + grid
|
| 398 |
+
|
| 399 |
+
# normalised grid to [-1, 1]
|
| 400 |
+
offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2])
|
| 401 |
+
|
| 402 |
+
# BxCxHxW -> Bx1xCxHxW
|
| 403 |
+
tensor = tensor[:, None]
|
| 404 |
+
|
| 405 |
+
# Bx1xCxHxW -> BxNxCxHxW
|
| 406 |
+
tensor = tensor.expand(-1, N, -1, -1, -1)
|
| 407 |
+
|
| 408 |
+
# BxNxCxHxW -> (B*N)xCxHxW
|
| 409 |
+
tensor = tensor.reshape((B * N), C, W, H)
|
| 410 |
+
|
| 411 |
+
sampled = torch.nn.functional.grid_sample(
|
| 412 |
+
tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# NOTE: I am not sure it should be h, w or w, h here
|
| 416 |
+
# but okay for squares
|
| 417 |
+
sampled = sampled.reshape(B, N, C, h, w)
|
| 418 |
+
|
| 419 |
+
return sampled
|
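A minimal usage sketch of extract_glimpse above (assuming it is in scope; sizes are illustrative). It grid-samples N sub-pixel patches per image, which is the interpolation-based alternative mentioned in the refine_track comments.

import torch

imgs = torch.randn(2, 3, 64, 64)                 # B x C x H x W
centers = torch.rand(2, 10, 2) * 63              # B x N x (x, y) patch centers
patches = extract_glimpse(imgs, size=(7, 7), offsets=centers)
print(patches.shape)                             # torch.Size([2, 10, 3, 7, 7])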
capvector-pi05/src/vggt/dependency/track_modules/utils.py
ADDED
@@ -0,0 +1,216 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from https://github.com/facebookresearch/PoseDiffusion
|
| 8 |
+
# and https://github.com/facebookresearch/co-tracker/tree/main
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from typing import Optional, Tuple, Union
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
|
| 20 |
+
"""
|
| 21 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
| 22 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
| 23 |
+
Args:
|
| 24 |
+
- embed_dim: The embedding dimension.
|
| 25 |
+
- grid_size: The grid size.
|
| 26 |
+
Returns:
|
| 27 |
+
- pos_embed: The generated 2D positional embedding.
|
| 28 |
+
"""
|
| 29 |
+
if isinstance(grid_size, tuple):
|
| 30 |
+
grid_size_h, grid_size_w = grid_size
|
| 31 |
+
else:
|
| 32 |
+
grid_size_h = grid_size_w = grid_size
|
| 33 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
| 34 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
| 35 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
| 36 |
+
grid = torch.stack(grid, dim=0)
|
| 37 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
| 38 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 39 |
+
if return_grid:
|
| 40 |
+
return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
|
| 41 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
"""
|
| 46 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
- embed_dim: The embedding dimension.
|
| 50 |
+
- grid: The grid to generate the embedding from.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
- emb: The generated 2D positional embedding.
|
| 54 |
+
"""
|
| 55 |
+
assert embed_dim % 2 == 0
|
| 56 |
+
|
| 57 |
+
# use half of dimensions to encode grid_h
|
| 58 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 59 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 60 |
+
|
| 61 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
| 62 |
+
return emb
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
|
| 66 |
+
"""
|
| 67 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
- embed_dim: The embedding dimension.
|
| 71 |
+
- pos: The position to generate the embedding from.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
- emb: The generated 1D positional embedding.
|
| 75 |
+
"""
|
| 76 |
+
assert embed_dim % 2 == 0
|
| 77 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
| 78 |
+
omega /= embed_dim / 2.0
|
| 79 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 80 |
+
|
| 81 |
+
pos = pos.reshape(-1) # (M,)
|
| 82 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 83 |
+
|
| 84 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 85 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 86 |
+
|
| 87 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 88 |
+
return emb[None].float()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
- xy: The coordinates to generate the embedding from.
|
| 97 |
+
- C: The size of the embedding.
|
| 98 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
- pe: The generated 2D positional embedding.
|
| 102 |
+
"""
|
| 103 |
+
B, N, D = xy.shape
|
| 104 |
+
assert D == 2
|
| 105 |
+
|
| 106 |
+
x = xy[:, :, 0:1]
|
| 107 |
+
y = xy[:, :, 1:2]
|
| 108 |
+
div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
|
| 109 |
+
|
| 110 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 111 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 112 |
+
|
| 113 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
| 114 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
| 115 |
+
|
| 116 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
| 117 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
| 118 |
+
|
| 119 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
| 120 |
+
if cat_coords:
|
| 121 |
+
pe = torch.cat([xy, pe], dim=2)  # (B, N, 2*C + 2)
|
| 122 |
+
return pe
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
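A quick shape check (not from the diff, assuming the helpers above are in scope; values illustrative) for the two positional-embedding entry points: the grid variant returns a 1 x D x H x W map, the coordinate variant returns one vector per track point.

import torch

pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
print(pos_embed.shape)            # torch.Size([1, 64, 8, 8])

xy = torch.rand(2, 100, 2) * 32   # B x N x 2 track coordinates
pe = get_2d_embedding(xy, C=64, cat_coords=True)
print(pe.shape)                   # torch.Size([2, 100, 130])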
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
| 126 |
+
r"""Sample a tensor using bilinear interpolation
|
| 127 |
+
|
| 128 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
| 129 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
| 130 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
| 131 |
+
convention.
|
| 132 |
+
|
| 133 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
| 134 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
| 135 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
| 136 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
| 137 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
| 138 |
+
|
| 139 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
| 140 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
| 141 |
+
that in this case the order of the components is slightly different
|
| 142 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
| 143 |
+
|
| 144 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
| 145 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
| 146 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
| 147 |
+
pixel.
|
| 148 |
+
|
| 149 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
| 150 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
| 151 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
| 152 |
+
pixel.
|
| 153 |
+
|
| 154 |
+
Similar conventions apply to the :math:`y` for the range
|
| 155 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
| 156 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
input (Tensor): batch of input images.
|
| 160 |
+
coords (Tensor): batch of coordinates.
|
| 161 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
| 162 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Tensor: sampled points.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
sizes = input.shape[2:]
|
| 169 |
+
|
| 170 |
+
assert len(sizes) in [2, 3]
|
| 171 |
+
|
| 172 |
+
if len(sizes) == 3:
|
| 173 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
| 174 |
+
coords = coords[..., [1, 2, 0]]
|
| 175 |
+
|
| 176 |
+
if align_corners:
|
| 177 |
+
coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device)
|
| 178 |
+
else:
|
| 179 |
+
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
| 180 |
+
|
| 181 |
+
coords -= 1
|
| 182 |
+
|
| 183 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def sample_features4d(input, coords):
|
| 187 |
+
r"""Sample spatial features
|
| 188 |
+
|
| 189 |
+
`sample_features4d(input, coords)` samples the spatial features
|
| 190 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
| 191 |
+
|
| 192 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
| 193 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
| 194 |
+
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
| 195 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
| 196 |
+
|
| 197 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
| 198 |
+
R, C)`.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
input (Tensor): spatial features.
|
| 202 |
+
coords (Tensor): points.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
Tensor: sampled features.
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
B, _, _, _ = input.shape
|
| 209 |
+
|
| 210 |
+
# B R 2 -> B R 1 2
|
| 211 |
+
coords = coords.unsqueeze(2)
|
| 212 |
+
|
| 213 |
+
# B C R 1
|
| 214 |
+
feats = bilinear_sampler(input, coords)
|
| 215 |
+
|
| 216 |
+
return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
|
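A minimal usage sketch of sample_features4d above (assuming it is in scope; sizes illustrative). It reads one feature vector per (x, y) point via the bilinear_sampler convention with align_corners=True.

import torch

feat_map = torch.randn(2, 128, 32, 32)     # B x C x H x W
points = torch.rand(2, 50, 2) * 31         # B x R x (x, y)
feats = sample_features4d(feat_map, points)
print(feats.shape)                         # torch.Size([2, 50, 128])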
capvector-pi05/src/vggt/dependency/track_predict.py
ADDED
@@ -0,0 +1,326 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from .vggsfm_utils import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def predict_tracks(
|
| 13 |
+
images,
|
| 14 |
+
conf=None,
|
| 15 |
+
points_3d=None,
|
| 16 |
+
masks=None,
|
| 17 |
+
max_query_pts=2048,
|
| 18 |
+
query_frame_num=5,
|
| 19 |
+
keypoint_extractor="aliked+sp",
|
| 20 |
+
max_points_num=163840,
|
| 21 |
+
fine_tracking=True,
|
| 22 |
+
complete_non_vis=True,
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Predict tracks for the given images and masks.
|
| 26 |
+
|
| 27 |
+
TODO: support non-square images
|
| 28 |
+
TODO: support masks
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
This function predicts the tracks for the given images and masks using the specified query method
|
| 32 |
+
and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
images: Tensor of shape [S, 3, H, W] containing the input images.
|
| 36 |
+
conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None.
|
| 37 |
+
points_3d: Tensor containing 3D points. Default is None.
|
| 38 |
+
masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None.
|
| 39 |
+
max_query_pts: Maximum number of query points. Default is 2048.
|
| 40 |
+
query_frame_num: Number of query frames to use. Default is 5.
|
| 41 |
+
keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp".
|
| 42 |
+
max_points_num: Maximum number of points to process at once. Default is 163840.
|
| 43 |
+
fine_tracking: Whether to use fine tracking. Default is True.
|
| 44 |
+
complete_non_vis: Whether to augment non-visible frames. Default is True.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
pred_tracks: Numpy array containing the predicted tracks.
|
| 48 |
+
pred_vis_scores: Numpy array containing the visibility scores for the tracks.
|
| 49 |
+
pred_confs: Numpy array containing the confidence scores for the tracks.
|
| 50 |
+
pred_points_3d: Numpy array containing the 3D points for the tracks.
|
| 51 |
+
pred_colors: Numpy array containing the point colors for the tracks. (0, 255)
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
device = images.device
|
| 55 |
+
dtype = images.dtype
|
| 56 |
+
tracker = build_vggsfm_tracker().to(device, dtype)
|
| 57 |
+
|
| 58 |
+
# Find query frames
|
| 59 |
+
query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device)
|
| 60 |
+
|
| 61 |
+
# Make sure frame 0 is always the first query frame
|
| 62 |
+
if 0 in query_frame_indexes:
|
| 63 |
+
query_frame_indexes.remove(0)
|
| 64 |
+
query_frame_indexes = [0, *query_frame_indexes]
|
| 65 |
+
|
| 66 |
+
# TODO: add the functionality to handle the masks
|
| 67 |
+
keypoint_extractors = initialize_feature_extractors(
|
| 68 |
+
max_query_pts, extractor_method=keypoint_extractor, device=device
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
pred_tracks = []
|
| 72 |
+
pred_vis_scores = []
|
| 73 |
+
pred_confs = []
|
| 74 |
+
pred_points_3d = []
|
| 75 |
+
pred_colors = []
|
| 76 |
+
|
| 77 |
+
fmaps_for_tracker = tracker.process_images_to_fmaps(images)
|
| 78 |
+
|
| 79 |
+
if fine_tracking:
|
| 80 |
+
print("For faster inference, consider disabling fine_tracking")
|
| 81 |
+
|
| 82 |
+
for query_index in query_frame_indexes:
|
| 83 |
+
print(f"Predicting tracks for query frame {query_index}")
|
| 84 |
+
pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query(
|
| 85 |
+
query_index,
|
| 86 |
+
images,
|
| 87 |
+
conf,
|
| 88 |
+
points_3d,
|
| 89 |
+
fmaps_for_tracker,
|
| 90 |
+
keypoint_extractors,
|
| 91 |
+
tracker,
|
| 92 |
+
max_points_num,
|
| 93 |
+
fine_tracking,
|
| 94 |
+
device,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
pred_tracks.append(pred_track)
|
| 98 |
+
pred_vis_scores.append(pred_vis)
|
| 99 |
+
pred_confs.append(pred_conf)
|
| 100 |
+
pred_points_3d.append(pred_point_3d)
|
| 101 |
+
pred_colors.append(pred_color)
|
| 102 |
+
|
| 103 |
+
if complete_non_vis:
|
| 104 |
+
pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames(
|
| 105 |
+
pred_tracks,
|
| 106 |
+
pred_vis_scores,
|
| 107 |
+
pred_confs,
|
| 108 |
+
pred_points_3d,
|
| 109 |
+
pred_colors,
|
| 110 |
+
images,
|
| 111 |
+
conf,
|
| 112 |
+
points_3d,
|
| 113 |
+
fmaps_for_tracker,
|
| 114 |
+
keypoint_extractors,
|
| 115 |
+
tracker,
|
| 116 |
+
max_points_num,
|
| 117 |
+
fine_tracking,
|
| 118 |
+
min_vis=500,
|
| 119 |
+
non_vis_thresh=0.1,
|
| 120 |
+
device=device,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
pred_tracks = np.concatenate(pred_tracks, axis=1)
|
| 124 |
+
pred_vis_scores = np.concatenate(pred_vis_scores, axis=1)
|
| 125 |
+
pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None
|
| 126 |
+
pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None
|
| 127 |
+
pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None
|
| 128 |
+
|
| 129 |
+
# from vggt.utils.visual_track import visualize_tracks_on_images
|
| 130 |
+
# visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals")
|
| 131 |
+
|
| 132 |
+
return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _forward_on_query(
|
| 136 |
+
query_index,
|
| 137 |
+
images,
|
| 138 |
+
conf,
|
| 139 |
+
points_3d,
|
| 140 |
+
fmaps_for_tracker,
|
| 141 |
+
keypoint_extractors,
|
| 142 |
+
tracker,
|
| 143 |
+
max_points_num,
|
| 144 |
+
fine_tracking,
|
| 145 |
+
device,
|
| 146 |
+
):
|
| 147 |
+
"""
|
| 148 |
+
Process a single query frame for track prediction.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
query_index: Index of the query frame
|
| 152 |
+
images: Tensor of shape [S, 3, H, W] containing the input images
|
| 153 |
+
conf: Confidence tensor
|
| 154 |
+
points_3d: 3D points tensor
|
| 155 |
+
fmaps_for_tracker: Feature maps for the tracker
|
| 156 |
+
keypoint_extractors: Initialized feature extractors
|
| 157 |
+
tracker: VGG-SFM tracker
|
| 158 |
+
max_points_num: Maximum number of points to process at once
|
| 159 |
+
fine_tracking: Whether to use fine tracking
|
| 160 |
+
device: Device to use for computation
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
pred_track: Predicted tracks
|
| 164 |
+
pred_vis: Visibility scores for the tracks
|
| 165 |
+
pred_conf: Confidence scores for the tracks
|
| 166 |
+
pred_point_3d: 3D points for the tracks
|
| 167 |
+
pred_color: Point colors for the tracks (0, 255)
|
| 168 |
+
"""
|
| 169 |
+
frame_num, _, height, width = images.shape
|
| 170 |
+
|
| 171 |
+
query_image = images[query_index]
|
| 172 |
+
query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False)
|
| 173 |
+
query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)]
|
| 174 |
+
|
| 175 |
+
# Extract the color at the keypoint locations
|
| 176 |
+
query_points_long = query_points.squeeze(0).round().long()
|
| 177 |
+
pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]]
|
| 178 |
+
pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 179 |
+
|
| 180 |
+
# Query the confidence and points_3d at the keypoint locations
|
| 181 |
+
if (conf is not None) and (points_3d is not None):
|
| 182 |
+
assert height == width
|
| 183 |
+
assert conf.shape[-2] == conf.shape[-1]
|
| 184 |
+
assert conf.shape[:3] == points_3d.shape[:3]
|
| 185 |
+
scale = conf.shape[-1] / width
|
| 186 |
+
|
| 187 |
+
query_points_scaled = (query_points.squeeze(0) * scale).round().long()
|
| 188 |
+
query_points_scaled = query_points_scaled.cpu().numpy()
|
| 189 |
+
|
| 190 |
+
pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]]
|
| 191 |
+
pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]]
|
| 192 |
+
|
| 193 |
+
# heuristic to remove low confidence points
|
| 194 |
+
# should I export this as an input parameter?
|
| 195 |
+
valid_mask = pred_conf > 1.2
|
| 196 |
+
if valid_mask.sum() > 512:
|
| 197 |
+
query_points = query_points[:, valid_mask] # Make sure shape is compatible
|
| 198 |
+
pred_conf = pred_conf[valid_mask]
|
| 199 |
+
pred_point_3d = pred_point_3d[valid_mask]
|
| 200 |
+
pred_color = pred_color[valid_mask]
|
| 201 |
+
else:
|
| 202 |
+
pred_conf = None
|
| 203 |
+
pred_point_3d = None
|
| 204 |
+
|
| 205 |
+
reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
|
| 206 |
+
|
| 207 |
+
images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0)
|
| 208 |
+
images_feed = images_feed[None] # add batch dimension
|
| 209 |
+
fmaps_feed = fmaps_feed[None] # add batch dimension
|
| 210 |
+
|
| 211 |
+
all_points_num = images_feed.shape[1] * query_points.shape[1]
|
| 212 |
+
|
| 213 |
+
# Don't need to be scared, this is just chunking to make GPU happy
|
| 214 |
+
if all_points_num > max_points_num:
|
| 215 |
+
num_splits = (all_points_num + max_points_num - 1) // max_points_num
|
| 216 |
+
query_points = torch.chunk(query_points, num_splits, dim=1)
|
| 217 |
+
else:
|
| 218 |
+
query_points = [query_points]
|
| 219 |
+
|
| 220 |
+
pred_track, pred_vis, _ = predict_tracks_in_chunks(
|
| 221 |
+
tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1)
|
| 225 |
+
|
| 226 |
+
pred_track = pred_track.squeeze(0).float().cpu().numpy()
|
| 227 |
+
pred_vis = pred_vis.squeeze(0).float().cpu().numpy()
|
| 228 |
+
|
| 229 |
+
return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _augment_non_visible_frames(
|
| 233 |
+
pred_tracks: list, # ← running list of np.ndarrays
|
| 234 |
+
pred_vis_scores: list, # ← running list of np.ndarrays
|
| 235 |
+
pred_confs: list, # ← running list of np.ndarrays for confidence scores
|
| 236 |
+
pred_points_3d: list, # ← running list of np.ndarrays for 3D points
|
| 237 |
+
pred_colors: list, # ← running list of np.ndarrays for colors
|
| 238 |
+
images: torch.Tensor,
|
| 239 |
+
conf,
|
| 240 |
+
points_3d,
|
| 241 |
+
fmaps_for_tracker,
|
| 242 |
+
keypoint_extractors,
|
| 243 |
+
tracker,
|
| 244 |
+
max_points_num: int,
|
| 245 |
+
fine_tracking: bool,
|
| 246 |
+
*,
|
| 247 |
+
min_vis: int = 500,
|
| 248 |
+
non_vis_thresh: float = 0.1,
|
| 249 |
+
device: torch.device = None,
|
| 250 |
+
):
|
| 251 |
+
"""
|
| 252 |
+
Augment tracking for frames with insufficient visibility.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
pred_tracks: List of numpy arrays containing predicted tracks.
|
| 256 |
+
pred_vis_scores: List of numpy arrays containing visibility scores.
|
| 257 |
+
pred_confs: List of numpy arrays containing confidence scores.
|
| 258 |
+
pred_points_3d: List of numpy arrays containing 3D points.
|
| 259 |
+
pred_colors: List of numpy arrays containing point colors.
|
| 260 |
+
images: Tensor of shape [S, 3, H, W] containing the input images.
|
| 261 |
+
conf: Tensor of shape [S, 1, H, W] containing confidence scores
|
| 262 |
+
points_3d: Tensor containing 3D points
|
| 263 |
+
fmaps_for_tracker: Feature maps for the tracker
|
| 264 |
+
keypoint_extractors: Initialized feature extractors
|
| 265 |
+
tracker: VGG-SFM tracker
|
| 266 |
+
max_points_num: Maximum number of points to process at once
|
| 267 |
+
fine_tracking: Whether to use fine tracking
|
| 268 |
+
min_vis: Minimum visibility threshold
|
| 269 |
+
non_vis_thresh: Non-visibility threshold
|
| 270 |
+
device: Device to use for computation
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists.
|
| 274 |
+
"""
|
| 275 |
+
last_query = -1
|
| 276 |
+
final_trial = False
|
| 277 |
+
cur_extractors = keypoint_extractors # may be replaced on the final trial
|
| 278 |
+
|
| 279 |
+
while True:
|
| 280 |
+
# Visibility per frame
|
| 281 |
+
vis_array = np.concatenate(pred_vis_scores, axis=1)
|
| 282 |
+
|
| 283 |
+
# Count frames with sufficient visibility using numpy
|
| 284 |
+
sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
|
| 285 |
+
non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
|
| 286 |
+
|
| 287 |
+
if len(non_vis_frames) == 0:
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
+
print("Processing non visible frames:", non_vis_frames)
|
| 291 |
+
|
| 292 |
+
# Decide the frames & extractor for this round
|
| 293 |
+
if non_vis_frames[0] == last_query:
|
| 294 |
+
# Same frame failed twice - final "all-in" attempt
|
| 295 |
+
final_trial = True
|
| 296 |
+
cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device)
|
| 297 |
+
query_frame_list = non_vis_frames # blast them all at once
|
| 298 |
+
else:
|
| 299 |
+
query_frame_list = [non_vis_frames[0]] # Process one at a time
|
| 300 |
+
|
| 301 |
+
last_query = non_vis_frames[0]
|
| 302 |
+
|
| 303 |
+
# Run the tracker for every selected frame
|
| 304 |
+
for query_index in query_frame_list:
|
| 305 |
+
new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
|
| 306 |
+
query_index,
|
| 307 |
+
images,
|
| 308 |
+
conf,
|
| 309 |
+
points_3d,
|
| 310 |
+
fmaps_for_tracker,
|
| 311 |
+
cur_extractors,
|
| 312 |
+
tracker,
|
| 313 |
+
max_points_num,
|
| 314 |
+
fine_tracking,
|
| 315 |
+
device,
|
| 316 |
+
)
|
| 317 |
+
pred_tracks.append(new_track)
|
| 318 |
+
pred_vis_scores.append(new_vis)
|
| 319 |
+
pred_confs.append(new_conf)
|
| 320 |
+
pred_points_3d.append(new_point_3d)
|
| 321 |
+
pred_colors.append(new_color)
|
| 322 |
+
|
| 323 |
+
if final_trial:
|
| 324 |
+
break # Stop after final attempt
|
| 325 |
+
|
| 326 |
+
return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
|
capvector-pi05/src/vggt/dependency/vggsfm_tracker.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from functools import partial
|
| 13 |
+
from torch import nn, einsum
|
| 14 |
+
from einops import rearrange, repeat
|
| 15 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 16 |
+
|
| 17 |
+
from hydra.utils import instantiate
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
|
| 20 |
+
from .track_modules.track_refine import refine_track
|
| 21 |
+
from .track_modules.blocks import BasicEncoder, ShallowEncoder
|
| 22 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TrackerPredictor(nn.Module):
|
| 26 |
+
def __init__(self, **extra_args):
|
| 27 |
+
super(TrackerPredictor, self).__init__()
|
| 28 |
+
"""
|
| 29 |
+
Initializes the tracker predictor.
|
| 30 |
+
|
| 31 |
+
Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor,
|
| 32 |
+
check track_modules/base_track_predictor.py
|
| 33 |
+
|
| 34 |
+
Both coarse_fnet and fine_fnet are constructed as a 2D CNN network
|
| 35 |
+
check track_modules/blocks.py for BasicEncoder and ShallowEncoder
|
| 36 |
+
"""
|
| 37 |
+
# Define coarse predictor configuration
|
| 38 |
+
coarse_stride = 4
|
| 39 |
+
self.coarse_down_ratio = 2
|
| 40 |
+
|
| 41 |
+
# Create networks directly instead of using instantiate
|
| 42 |
+
self.coarse_fnet = BasicEncoder(stride=coarse_stride)
|
| 43 |
+
self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)
|
| 44 |
+
|
| 45 |
+
# Create fine predictor with stride = 1
|
| 46 |
+
self.fine_fnet = ShallowEncoder(stride=1)
|
| 47 |
+
self.fine_predictor = BaseTrackerPredictor(
|
| 48 |
+
stride=1,
|
| 49 |
+
depth=4,
|
| 50 |
+
corr_levels=3,
|
| 51 |
+
corr_radius=3,
|
| 52 |
+
latent_dim=32,
|
| 53 |
+
hidden_size=256,
|
| 54 |
+
fine=True,
|
| 55 |
+
use_spaceatt=False,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def forward(
|
| 59 |
+
self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960
|
| 60 |
+
):
|
| 61 |
+
"""
|
| 62 |
+
Args:
|
| 63 |
+
images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
|
| 64 |
+
query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
|
| 65 |
+
fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None.
|
| 66 |
+
coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
|
| 67 |
+
inference (bool, optional): Whether to perform inference. Defaults to True.
|
| 68 |
+
fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
if fmaps is None:
|
| 75 |
+
batch_num, frame_num, image_dim, height, width = images.shape
|
| 76 |
+
reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)
|
| 77 |
+
fmaps = self.process_images_to_fmaps(reshaped_image)
|
| 78 |
+
fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1])
|
| 79 |
+
|
| 80 |
+
if inference:
|
| 81 |
+
torch.cuda.empty_cache()
|
| 82 |
+
|
| 83 |
+
# Coarse prediction
|
| 84 |
+
coarse_pred_track_lists, pred_vis = self.coarse_predictor(
|
| 85 |
+
query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio
|
| 86 |
+
)
|
| 87 |
+
coarse_pred_track = coarse_pred_track_lists[-1]
|
| 88 |
+
|
| 89 |
+
if inference:
|
| 90 |
+
torch.cuda.empty_cache()
|
| 91 |
+
|
| 92 |
+
if fine_tracking:
|
| 93 |
+
# Refine the coarse prediction
|
| 94 |
+
fine_pred_track, pred_score = refine_track(
|
| 95 |
+
images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
if inference:
|
| 99 |
+
torch.cuda.empty_cache()
|
| 100 |
+
else:
|
| 101 |
+
fine_pred_track = coarse_pred_track
|
| 102 |
+
pred_score = torch.ones_like(pred_vis)
|
| 103 |
+
|
| 104 |
+
return fine_pred_track, coarse_pred_track, pred_vis, pred_score
|
| 105 |
+
|
| 106 |
+
def process_images_to_fmaps(self, images):
|
| 107 |
+
"""
|
| 108 |
+
This function processes images for inference.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
images (torch.Tensor): The images to be processed with shape S x 3 x H x W.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
torch.Tensor: The processed feature maps.
|
| 115 |
+
"""
|
| 116 |
+
if self.coarse_down_ratio > 1:
|
| 117 |
+
# whether or not scale down the input images to save memory
|
| 118 |
+
fmaps = self.coarse_fnet(
|
| 119 |
+
F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True)
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
fmaps = self.coarse_fnet(images)
|
| 123 |
+
|
| 124 |
+
return fmaps
|
capvector-pi05/src/vggt/dependency/vggsfm_utils.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import warnings
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pycolmap
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from lightglue import ALIKED, SIFT, SuperPoint
|
| 16 |
+
|
| 17 |
+
from .vggsfm_tracker import TrackerPredictor
|
| 18 |
+
|
| 19 |
+
# Suppress verbose logging from dependencies
|
| 20 |
+
logging.getLogger("dinov2").setLevel(logging.WARNING)
|
| 21 |
+
warnings.filterwarnings("ignore", message="xFormers is available")
|
| 22 |
+
warnings.filterwarnings("ignore", message="dinov2")
|
| 23 |
+
|
| 24 |
+
# Constants
|
| 25 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 26 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_vggsfm_tracker(model_path=None):
|
| 30 |
+
"""
|
| 31 |
+
Build and initialize the VGGSfM tracker.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Initialized tracker model in eval mode.
|
| 38 |
+
"""
|
| 39 |
+
tracker = TrackerPredictor()
|
| 40 |
+
|
| 41 |
+
if model_path is None:
|
| 42 |
+
default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
|
| 43 |
+
tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
|
| 44 |
+
else:
|
| 45 |
+
tracker.load_state_dict(torch.load(model_path))
|
| 46 |
+
|
| 47 |
+
tracker.eval()
|
| 48 |
+
return tracker
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def generate_rank_by_dino(
|
| 52 |
+
images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Generate a ranking of frames using DINO ViT features.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
|
| 59 |
+
query_frame_num: Number of frames to select
|
| 60 |
+
image_size: Size to resize images to before processing
|
| 61 |
+
model_name: Name of the DINO model to use
|
| 62 |
+
device: Device to run the model on
|
| 63 |
+
spatial_similarity: Whether to use spatial token similarity or CLS token similarity
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
List of frame indices ranked by their representativeness
|
| 67 |
+
"""
|
| 68 |
+
# Resize images to the target size
|
| 69 |
+
images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False)
|
| 70 |
+
|
| 71 |
+
# Load DINO model
|
| 72 |
+
dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
|
| 73 |
+
dino_v2_model.eval()
|
| 74 |
+
dino_v2_model = dino_v2_model.to(device)
|
| 75 |
+
|
| 76 |
+
# Normalize images using ResNet normalization
|
| 77 |
+
resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
|
| 78 |
+
resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
|
| 79 |
+
images_resnet_norm = (images - resnet_mean) / resnet_std
|
| 80 |
+
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
|
| 83 |
+
|
| 84 |
+
# Process features based on similarity type
|
| 85 |
+
if spatial_similarity:
|
| 86 |
+
frame_feat = frame_feat["x_norm_patchtokens"]
|
| 87 |
+
frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
|
| 88 |
+
|
| 89 |
+
# Compute the similarity matrix
|
| 90 |
+
frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
|
| 91 |
+
similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
|
| 92 |
+
similarity_matrix = similarity_matrix.mean(dim=0)
|
| 93 |
+
else:
|
| 94 |
+
frame_feat = frame_feat["x_norm_clstoken"]
|
| 95 |
+
frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
|
| 96 |
+
similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
|
| 97 |
+
|
| 98 |
+
distance_matrix = 100 - similarity_matrix.clone()
|
| 99 |
+
|
| 100 |
+
# Ignore self-pairing
|
| 101 |
+
similarity_matrix.fill_diagonal_(-100)
|
| 102 |
+
similarity_sum = similarity_matrix.sum(dim=1)
|
| 103 |
+
|
| 104 |
+
# Find the most common frame
|
| 105 |
+
most_common_frame_index = torch.argmax(similarity_sum).item()
|
| 106 |
+
|
| 107 |
+
# Conduct FPS sampling starting from the most common frame
|
| 108 |
+
fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)
|
| 109 |
+
|
| 110 |
+
# Clean up all tensors and models to free memory
|
| 111 |
+
del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
|
| 112 |
+
del dino_v2_model
|
| 113 |
+
torch.cuda.empty_cache()
|
| 114 |
+
|
| 115 |
+
return fps_idx
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
|
| 119 |
+
"""
|
| 120 |
+
Farthest point sampling algorithm to select diverse frames.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
distance_matrix: Matrix of distances between frames
|
| 124 |
+
num_samples: Number of frames to select
|
| 125 |
+
most_common_frame_index: Index of the first frame to select
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
List of selected frame indices
|
| 129 |
+
"""
|
| 130 |
+
distance_matrix = distance_matrix.clamp(min=0)
|
| 131 |
+
N = distance_matrix.size(0)
|
| 132 |
+
|
| 133 |
+
# Initialize with the most common frame
|
| 134 |
+
selected_indices = [most_common_frame_index]
|
| 135 |
+
check_distances = distance_matrix[selected_indices]
|
| 136 |
+
|
| 137 |
+
while len(selected_indices) < num_samples:
|
| 138 |
+
# Find the farthest point from the current set of selected points
|
| 139 |
+
farthest_point = torch.argmax(check_distances)
|
| 140 |
+
selected_indices.append(farthest_point.item())
|
| 141 |
+
|
| 142 |
+
check_distances = distance_matrix[farthest_point]
|
| 143 |
+
# Mark already selected points to avoid selecting them again
|
| 144 |
+
check_distances[selected_indices] = 0
|
| 145 |
+
|
| 146 |
+
# Break if all points have been selected
|
| 147 |
+
if len(selected_indices) == N:
|
| 148 |
+
break
|
| 149 |
+
|
| 150 |
+
return selected_indices
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def calculate_index_mappings(query_index, S, device=None):
|
| 154 |
+
"""
|
| 155 |
+
Construct an order that switches [query_index] and [0]
|
| 156 |
+
so that the content of query_index would be placed at [0].
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
query_index: Index to swap with 0
|
| 160 |
+
S: Total number of elements
|
| 161 |
+
device: Device to place the tensor on
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
Tensor of indices with the swapped order
|
| 165 |
+
"""
|
| 166 |
+
new_order = torch.arange(S)
|
| 167 |
+
new_order[0] = query_index
|
| 168 |
+
new_order[query_index] = 0
|
| 169 |
+
if device is not None:
|
| 170 |
+
new_order = new_order.to(device)
|
| 171 |
+
return new_order
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def switch_tensor_order(tensors, order, dim=1):
|
| 175 |
+
"""
|
| 176 |
+
Reorder tensors along a specific dimension according to the given order.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
tensors: List of tensors to reorder
|
| 180 |
+
order: Tensor of indices specifying the new order
|
| 181 |
+
dim: Dimension along which to reorder
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
List of reordered tensors
|
| 185 |
+
"""
|
| 186 |
+
return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"):
|
| 190 |
+
"""
|
| 191 |
+
Initialize feature extractors that can be reused based on a method string.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
max_query_num: Maximum number of keypoints to extract
|
| 195 |
+
det_thres: Detection threshold for keypoint extraction
|
| 196 |
+
extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
|
| 197 |
+
device: Device to run extraction on
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
Dictionary of initialized extractors
|
| 201 |
+
"""
|
| 202 |
+
extractors = {}
|
| 203 |
+
methods = extractor_method.lower().split("+")
|
| 204 |
+
|
| 205 |
+
for method in methods:
|
| 206 |
+
method = method.strip()
|
| 207 |
+
if method == "aliked":
|
| 208 |
+
aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
|
| 209 |
+
extractors["aliked"] = aliked_extractor.to(device).eval()
|
| 210 |
+
elif method == "sp":
|
| 211 |
+
sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres)
|
| 212 |
+
extractors["sp"] = sp_extractor.to(device).eval()
|
| 213 |
+
elif method == "sift":
|
| 214 |
+
sift_extractor = SIFT(max_num_keypoints=max_query_num)
|
| 215 |
+
extractors["sift"] = sift_extractor.to(device).eval()
|
| 216 |
+
else:
|
| 217 |
+
print(f"Warning: Unknown feature extractor '{method}', ignoring.")
|
| 218 |
+
|
| 219 |
+
if not extractors:
|
| 220 |
+
print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.")
|
| 221 |
+
aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
|
| 222 |
+
extractors["aliked"] = aliked_extractor.to(device).eval()
|
| 223 |
+
|
| 224 |
+
return extractors
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def extract_keypoints(query_image, extractors, round_keypoints=True):
|
| 228 |
+
"""
|
| 229 |
+
Extract keypoints using pre-initialized feature extractors.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
query_image: Input image tensor (3xHxW, range [0, 1])
|
| 233 |
+
extractors: Dictionary of initialized extractors
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
Tensor of keypoint coordinates (1xNx2)
|
| 237 |
+
"""
|
| 238 |
+
query_points = None
|
| 239 |
+
|
| 240 |
+
with torch.no_grad():
|
| 241 |
+
for extractor_name, extractor in extractors.items():
|
| 242 |
+
query_points_data = extractor.extract(query_image, invalid_mask=None)
|
| 243 |
+
extractor_points = query_points_data["keypoints"]
|
| 244 |
+
if round_keypoints:
|
| 245 |
+
extractor_points = extractor_points.round()
|
| 246 |
+
|
| 247 |
+
if query_points is not None:
|
| 248 |
+
query_points = torch.cat([query_points, extractor_points], dim=1)
|
| 249 |
+
else:
|
| 250 |
+
query_points = extractor_points
|
| 251 |
+
|
| 252 |
+
return query_points
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def predict_tracks_in_chunks(
|
| 256 |
+
track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960
|
| 257 |
+
):
|
| 258 |
+
"""
|
| 259 |
+
Process a list of query points to avoid memory issues.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
track_predictor (object): The track predictor object used for predicting tracks.
|
| 263 |
+
images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images.
|
| 264 |
+
query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points.
|
| 265 |
+
fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
|
| 266 |
+
fine_tracking (bool): Whether to perform fine tracking.
|
| 267 |
+
num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
tuple: A tuple containing the concatenated predicted tracks, visibility, and scores.
|
| 271 |
+
"""
|
| 272 |
+
# If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility
|
| 273 |
+
if not isinstance(query_points_list, (list, tuple)):
|
| 274 |
+
query_points = query_points_list
|
| 275 |
+
if num_splits is None:
|
| 276 |
+
num_splits = 1
|
| 277 |
+
query_points_list = torch.chunk(query_points, num_splits, dim=1)
|
| 278 |
+
|
| 279 |
+
# Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple)
|
| 280 |
+
if isinstance(query_points_list, tuple):
|
| 281 |
+
query_points_list = list(query_points_list)
|
| 282 |
+
|
| 283 |
+
fine_pred_track_list = []
|
| 284 |
+
pred_vis_list = []
|
| 285 |
+
pred_score_list = []
|
| 286 |
+
|
| 287 |
+
for split_points in query_points_list:
|
| 288 |
+
# Feed into track predictor for each split
|
| 289 |
+
fine_pred_track, _, pred_vis, pred_score = track_predictor(
|
| 290 |
+
images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk
|
| 291 |
+
)
|
| 292 |
+
fine_pred_track_list.append(fine_pred_track)
|
| 293 |
+
pred_vis_list.append(pred_vis)
|
| 294 |
+
pred_score_list.append(pred_score)
|
| 295 |
+
|
| 296 |
+
# Concatenate the results from all splits
|
| 297 |
+
fine_pred_track = torch.cat(fine_pred_track_list, dim=2)
|
| 298 |
+
pred_vis = torch.cat(pred_vis_list, dim=2)
|
| 299 |
+
|
| 300 |
+
if pred_score is not None:
|
| 301 |
+
pred_score = torch.cat(pred_score_list, dim=2)
|
| 302 |
+
else:
|
| 303 |
+
pred_score = None
|
| 304 |
+
|
| 305 |
+
return fine_pred_track, pred_vis, pred_score
|
capvector-pi05/src/vggt/heads/camera_head.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from vggt.layers import Mlp
|
| 15 |
+
from vggt.layers.block import Block
|
| 16 |
+
from vggt.heads.head_act import activate_pose
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CameraHead(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 22 |
+
|
| 23 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
dim_in: int = 2048,
|
| 29 |
+
trunk_depth: int = 4,
|
| 30 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
| 31 |
+
num_heads: int = 16,
|
| 32 |
+
mlp_ratio: int = 4,
|
| 33 |
+
init_values: float = 0.01,
|
| 34 |
+
trans_act: str = "linear",
|
| 35 |
+
quat_act: str = "linear",
|
| 36 |
+
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
|
| 37 |
+
):
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 41 |
+
self.target_dim = 9
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
| 44 |
+
|
| 45 |
+
self.trans_act = trans_act
|
| 46 |
+
self.quat_act = quat_act
|
| 47 |
+
self.fl_act = fl_act
|
| 48 |
+
self.trunk_depth = trunk_depth
|
| 49 |
+
|
| 50 |
+
# Build the trunk using a sequence of transformer blocks.
|
| 51 |
+
self.trunk = nn.Sequential(
|
| 52 |
+
*[
|
| 53 |
+
Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
|
| 54 |
+
for _ in range(trunk_depth)
|
| 55 |
+
]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Normalizations for camera token and trunk output.
|
| 59 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 60 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 61 |
+
|
| 62 |
+
# Learnable empty camera pose token.
|
| 63 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 64 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 65 |
+
|
| 66 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 67 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 68 |
+
|
| 69 |
+
# Adaptive layer normalization without affine parameters.
|
| 70 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 71 |
+
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
|
| 72 |
+
|
| 73 |
+
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
|
| 74 |
+
"""
|
| 75 |
+
Forward pass to predict camera parameters.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
| 79 |
+
the last tensor is used for prediction.
|
| 80 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 84 |
+
"""
|
| 85 |
+
# Use tokens from the last block for camera prediction.
|
| 86 |
+
tokens = aggregated_tokens_list[-1]
|
| 87 |
+
|
| 88 |
+
# Extract the camera tokens
|
| 89 |
+
pose_tokens = tokens[:, :, 0]
|
| 90 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 91 |
+
|
| 92 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
| 93 |
+
return pred_pose_enc_list
|
| 94 |
+
|
| 95 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
| 96 |
+
"""
|
| 97 |
+
Iteratively refine camera pose predictions.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
|
| 101 |
+
num_iterations (int): Number of refinement iterations.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
list: List of activated camera encodings from each iteration.
|
| 105 |
+
"""
|
| 106 |
+
B, S, C = pose_tokens.shape
|
| 107 |
+
pred_pose_enc = None
|
| 108 |
+
pred_pose_enc_list = []
|
| 109 |
+
|
| 110 |
+
for _ in range(num_iterations):
|
| 111 |
+
# Use a learned empty pose for the first iteration.
|
| 112 |
+
if pred_pose_enc is None:
|
| 113 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 114 |
+
else:
|
| 115 |
+
# Detach the previous prediction to avoid backprop through time.
|
| 116 |
+
pred_pose_enc = pred_pose_enc.detach()
|
| 117 |
+
module_input = self.embed_pose(pred_pose_enc)
|
| 118 |
+
|
| 119 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 120 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
| 121 |
+
|
| 122 |
+
# Adaptive layer normalization and modulation.
|
| 123 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
| 124 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 125 |
+
|
| 126 |
+
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
|
| 127 |
+
# Compute the delta update for the pose encoding.
|
| 128 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
| 129 |
+
|
| 130 |
+
if pred_pose_enc is None:
|
| 131 |
+
pred_pose_enc = pred_pose_enc_delta
|
| 132 |
+
else:
|
| 133 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 134 |
+
|
| 135 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
| 136 |
+
activated_pose = activate_pose(
|
| 137 |
+
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
|
| 138 |
+
)
|
| 139 |
+
pred_pose_enc_list.append(activated_pose)
|
| 140 |
+
|
| 141 |
+
return pred_pose_enc_list
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
"""
|
| 146 |
+
Modulate the input tensor using scaling and shifting parameters.
|
| 147 |
+
"""
|
| 148 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 149 |
+
return x * (1 + scale) + shift
|
capvector-pi05/src/vggt/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from .head_act import activate_head
|
| 18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DPTHead(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
DPT Head for dense prediction tasks.
|
| 24 |
+
|
| 25 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 26 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 27 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim_in (int): Input dimension (channels).
|
| 31 |
+
patch_size (int, optional): Patch size. Default is 14.
|
| 32 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
| 33 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
| 34 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 35 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 36 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 37 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 38 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 39 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 40 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim_in: int,
|
| 46 |
+
patch_size: int = 14,
|
| 47 |
+
output_dim: int = 4,
|
| 48 |
+
activation: str = "inv_log",
|
| 49 |
+
conf_activation: str = "expp1",
|
| 50 |
+
features: int = 256,
|
| 51 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 52 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
| 53 |
+
pos_embed: bool = True,
|
| 54 |
+
feature_only: bool = False,
|
| 55 |
+
down_ratio: int = 1,
|
| 56 |
+
) -> None:
|
| 57 |
+
super(DPTHead, self).__init__()
|
| 58 |
+
self.patch_size = patch_size
|
| 59 |
+
self.activation = activation
|
| 60 |
+
self.conf_activation = conf_activation
|
| 61 |
+
self.pos_embed = pos_embed
|
| 62 |
+
self.feature_only = feature_only
|
| 63 |
+
self.down_ratio = down_ratio
|
| 64 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
| 65 |
+
|
| 66 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 67 |
+
|
| 68 |
+
# Projection layers for each output channel from tokens.
|
| 69 |
+
self.projects = nn.ModuleList(
|
| 70 |
+
[nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Resize layers for upsampling feature maps.
|
| 74 |
+
self.resize_layers = nn.ModuleList(
|
| 75 |
+
[
|
| 76 |
+
nn.ConvTranspose2d(
|
| 77 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 78 |
+
),
|
| 79 |
+
nn.ConvTranspose2d(
|
| 80 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 81 |
+
),
|
| 82 |
+
nn.Identity(),
|
| 83 |
+
nn.Conv2d(
|
| 84 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 85 |
+
),
|
| 86 |
+
]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.scratch = _make_scratch(out_channels, features, expand=False)
|
| 90 |
+
|
| 91 |
+
# Attach additional modules to scratch.
|
| 92 |
+
self.scratch.stem_transpose = None
|
| 93 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 94 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 95 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 96 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 97 |
+
|
| 98 |
+
head_features_1 = features
|
| 99 |
+
head_features_2 = 32
|
| 100 |
+
|
| 101 |
+
if feature_only:
|
| 102 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
| 103 |
+
else:
|
| 104 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 105 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 106 |
+
)
|
| 107 |
+
conv2_in_channels = head_features_1 // 2
|
| 108 |
+
|
| 109 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 110 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 111 |
+
nn.ReLU(inplace=True),
|
| 112 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def forward(
|
| 116 |
+
self,
|
| 117 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 118 |
+
images: torch.Tensor,
|
| 119 |
+
patch_start_idx: int,
|
| 120 |
+
frames_chunk_size: int = 8,
|
| 121 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 122 |
+
"""
|
| 123 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
| 124 |
+
Args:
|
| 125 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 126 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 127 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 128 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 129 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 130 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Tensor or Tuple[Tensor, Tensor]:
|
| 134 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 135 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
| 136 |
+
"""
|
| 137 |
+
B, S, _, H, W = images.shape
|
| 138 |
+
|
| 139 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 140 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 141 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 142 |
+
|
| 143 |
+
# Otherwise, process frames in chunks to manage memory usage
|
| 144 |
+
assert frames_chunk_size > 0
|
| 145 |
+
|
| 146 |
+
# Process frames in batches
|
| 147 |
+
all_preds = []
|
| 148 |
+
all_conf = []
|
| 149 |
+
|
| 150 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 151 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 152 |
+
|
| 153 |
+
# Process batch of frames
|
| 154 |
+
if self.feature_only:
|
| 155 |
+
chunk_output = self._forward_impl(
|
| 156 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 157 |
+
)
|
| 158 |
+
all_preds.append(chunk_output)
|
| 159 |
+
else:
|
| 160 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
| 161 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 162 |
+
)
|
| 163 |
+
all_preds.append(chunk_preds)
|
| 164 |
+
all_conf.append(chunk_conf)
|
| 165 |
+
|
| 166 |
+
# Concatenate results along the sequence dimension
|
| 167 |
+
if self.feature_only:
|
| 168 |
+
return torch.cat(all_preds, dim=1)
|
| 169 |
+
else:
|
| 170 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
| 171 |
+
|
| 172 |
+
def _forward_impl(
|
| 173 |
+
self,
|
| 174 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 175 |
+
images: torch.Tensor,
|
| 176 |
+
patch_start_idx: int,
|
| 177 |
+
frames_start_idx: int = None,
|
| 178 |
+
frames_end_idx: int = None,
|
| 179 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 180 |
+
"""
|
| 181 |
+
Implementation of the forward pass through the DPT head.
|
| 182 |
+
|
| 183 |
+
This method processes a specific chunk of frames from the sequence.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 187 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 188 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 189 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
| 190 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 194 |
+
"""
|
| 195 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 196 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
| 197 |
+
|
| 198 |
+
B, S, _, H, W = images.shape
|
| 199 |
+
|
| 200 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 201 |
+
|
| 202 |
+
out = []
|
| 203 |
+
dpt_idx = 0
|
| 204 |
+
|
| 205 |
+
for layer_idx in self.intermediate_layer_idx:
|
| 206 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 207 |
+
|
| 208 |
+
# Select frames if processing a chunk
|
| 209 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 210 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
| 211 |
+
|
| 212 |
+
x = x.reshape(B * S, -1, x.shape[-1])
|
| 213 |
+
|
| 214 |
+
x = self.norm(x)
|
| 215 |
+
|
| 216 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 217 |
+
|
| 218 |
+
x = self.projects[dpt_idx](x)
|
| 219 |
+
if self.pos_embed:
|
| 220 |
+
x = self._apply_pos_embed(x, W, H)
|
| 221 |
+
x = self.resize_layers[dpt_idx](x)
|
| 222 |
+
|
| 223 |
+
out.append(x)
|
| 224 |
+
dpt_idx += 1
|
| 225 |
+
|
| 226 |
+
# Fuse features from multiple layers.
|
| 227 |
+
out = self.scratch_forward(out)
|
| 228 |
+
# Interpolate fused output to match target image resolution.
|
| 229 |
+
out = custom_interpolate(
|
| 230 |
+
out,
|
| 231 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
| 232 |
+
mode="bilinear",
|
| 233 |
+
align_corners=True,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if self.pos_embed:
|
| 237 |
+
out = self._apply_pos_embed(out, W, H)
|
| 238 |
+
|
| 239 |
+
if self.feature_only:
|
| 240 |
+
return out.view(B, S, *out.shape[1:])
|
| 241 |
+
|
| 242 |
+
out = self.scratch.output_conv2(out)
|
| 243 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
| 244 |
+
|
| 245 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
| 246 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
| 247 |
+
return preds, conf
|
| 248 |
+
|
| 249 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 250 |
+
"""
|
| 251 |
+
Apply positional embedding to tensor x.
|
| 252 |
+
"""
|
| 253 |
+
patch_w = x.shape[-1]
|
| 254 |
+
patch_h = x.shape[-2]
|
| 255 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 256 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 257 |
+
pos_embed = pos_embed * ratio
|
| 258 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 259 |
+
return x + pos_embed
|
| 260 |
+
|
| 261 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 262 |
+
"""
|
| 263 |
+
Forward pass through the fusion blocks.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
features (List[Tensor]): List of feature maps from different layers.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Tensor: Fused feature map.
|
| 270 |
+
"""
|
| 271 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 272 |
+
|
| 273 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 274 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 275 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 276 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 277 |
+
|
| 278 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 279 |
+
del layer_4_rn, layer_4
|
| 280 |
+
|
| 281 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 282 |
+
del layer_3_rn, layer_3
|
| 283 |
+
|
| 284 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 285 |
+
del layer_2_rn, layer_2
|
| 286 |
+
|
| 287 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 288 |
+
del layer_1_rn, layer_1
|
| 289 |
+
|
| 290 |
+
out = self.scratch.output_conv1(out)
|
| 291 |
+
return out
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
################################################################################
|
| 295 |
+
# Modules
|
| 296 |
+
################################################################################
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
| 300 |
+
return FeatureFusionBlock(
|
| 301 |
+
features,
|
| 302 |
+
nn.ReLU(inplace=True),
|
| 303 |
+
deconv=False,
|
| 304 |
+
bn=False,
|
| 305 |
+
expand=False,
|
| 306 |
+
align_corners=True,
|
| 307 |
+
size=size,
|
| 308 |
+
has_residual=has_residual,
|
| 309 |
+
groups=groups,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 314 |
+
scratch = nn.Module()
|
| 315 |
+
out_shape1 = out_shape
|
| 316 |
+
out_shape2 = out_shape
|
| 317 |
+
out_shape3 = out_shape
|
| 318 |
+
if len(in_shape) >= 4:
|
| 319 |
+
out_shape4 = out_shape
|
| 320 |
+
|
| 321 |
+
if expand:
|
| 322 |
+
out_shape1 = out_shape
|
| 323 |
+
out_shape2 = out_shape * 2
|
| 324 |
+
out_shape3 = out_shape * 4
|
| 325 |
+
if len(in_shape) >= 4:
|
| 326 |
+
out_shape4 = out_shape * 8
|
| 327 |
+
|
| 328 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 329 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 330 |
+
)
|
| 331 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 332 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 333 |
+
)
|
| 334 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 335 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 336 |
+
)
|
| 337 |
+
if len(in_shape) >= 4:
|
| 338 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 339 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 340 |
+
)
|
| 341 |
+
return scratch
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class ResidualConvUnit(nn.Module):
|
| 345 |
+
"""Residual convolution module."""
|
| 346 |
+
|
| 347 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 348 |
+
"""Init.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
features (int): number of features
|
| 352 |
+
"""
|
| 353 |
+
super().__init__()
|
| 354 |
+
|
| 355 |
+
self.bn = bn
|
| 356 |
+
self.groups = groups
|
| 357 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 358 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 359 |
+
|
| 360 |
+
self.norm1 = None
|
| 361 |
+
self.norm2 = None
|
| 362 |
+
|
| 363 |
+
self.activation = activation
|
| 364 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 365 |
+
|
| 366 |
+
def forward(self, x):
|
| 367 |
+
"""Forward pass.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
x (tensor): input
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
tensor: output
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
out = self.activation(x)
|
| 377 |
+
out = self.conv1(out)
|
| 378 |
+
if self.norm1 is not None:
|
| 379 |
+
out = self.norm1(out)
|
| 380 |
+
|
| 381 |
+
out = self.activation(out)
|
| 382 |
+
out = self.conv2(out)
|
| 383 |
+
if self.norm2 is not None:
|
| 384 |
+
out = self.norm2(out)
|
| 385 |
+
|
| 386 |
+
return self.skip_add.add(out, x)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class FeatureFusionBlock(nn.Module):
|
| 390 |
+
"""Feature fusion block."""
|
| 391 |
+
|
| 392 |
+
def __init__(
|
| 393 |
+
self,
|
| 394 |
+
features,
|
| 395 |
+
activation,
|
| 396 |
+
deconv=False,
|
| 397 |
+
bn=False,
|
| 398 |
+
expand=False,
|
| 399 |
+
align_corners=True,
|
| 400 |
+
size=None,
|
| 401 |
+
has_residual=True,
|
| 402 |
+
groups=1,
|
| 403 |
+
):
|
| 404 |
+
"""Init.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
features (int): number of features
|
| 408 |
+
"""
|
| 409 |
+
super(FeatureFusionBlock, self).__init__()
|
| 410 |
+
|
| 411 |
+
self.deconv = deconv
|
| 412 |
+
self.align_corners = align_corners
|
| 413 |
+
self.groups = groups
|
| 414 |
+
self.expand = expand
|
| 415 |
+
out_features = features
|
| 416 |
+
if self.expand == True:
|
| 417 |
+
out_features = features // 2
|
| 418 |
+
|
| 419 |
+
self.out_conv = nn.Conv2d(
|
| 420 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
if has_residual:
|
| 424 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 425 |
+
|
| 426 |
+
self.has_residual = has_residual
|
| 427 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 428 |
+
|
| 429 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 430 |
+
self.size = size
|
| 431 |
+
|
| 432 |
+
def forward(self, *xs, size=None):
|
| 433 |
+
"""Forward pass.
|
| 434 |
+
|
| 435 |
+
Returns:
|
| 436 |
+
tensor: output
|
| 437 |
+
"""
|
| 438 |
+
output = xs[0]
|
| 439 |
+
|
| 440 |
+
if self.has_residual:
|
| 441 |
+
res = self.resConfUnit1(xs[1])
|
| 442 |
+
output = self.skip_add.add(output, res)
|
| 443 |
+
|
| 444 |
+
output = self.resConfUnit2(output)
|
| 445 |
+
|
| 446 |
+
if (size is None) and (self.size is None):
|
| 447 |
+
modifier = {"scale_factor": 2}
|
| 448 |
+
elif size is None:
|
| 449 |
+
modifier = {"size": self.size}
|
| 450 |
+
else:
|
| 451 |
+
modifier = {"size": size}
|
| 452 |
+
|
| 453 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 454 |
+
output = self.out_conv(output)
|
| 455 |
+
|
| 456 |
+
return output
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def custom_interpolate(
|
| 460 |
+
x: torch.Tensor,
|
| 461 |
+
size: Tuple[int, int] = None,
|
| 462 |
+
scale_factor: float = None,
|
| 463 |
+
mode: str = "bilinear",
|
| 464 |
+
align_corners: bool = True,
|
| 465 |
+
) -> torch.Tensor:
|
| 466 |
+
"""
|
| 467 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
| 468 |
+
"""
|
| 469 |
+
if size is None:
|
| 470 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 471 |
+
|
| 472 |
+
INT_MAX = 1610612736
|
| 473 |
+
|
| 474 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 475 |
+
|
| 476 |
+
if input_elements > INT_MAX:
|
| 477 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 478 |
+
interpolated_chunks = [
|
| 479 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
| 480 |
+
]
|
| 481 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 482 |
+
return x.contiguous()
|
| 483 |
+
else:
|
| 484 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
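A minimal usage sketch for `custom_interpolate` above (not part of the added file; it assumes this repo's `src/` directory is on `PYTHONPATH` so `vggt.heads.dpt_head` is importable). Small inputs fall through to a single `nn.functional.interpolate` call; the chunked path only engages when the requested output would exceed roughly 1.6e9 elements.

```python
import torch
from vggt.heads.dpt_head import custom_interpolate  # assumes src/ is importable

x = torch.randn(2, 64, 37, 37)                  # (B, C, H, W) feature map
up = custom_interpolate(x, size=(74, 74))       # explicit target size
up2 = custom_interpolate(x, scale_factor=2.0)   # size derived from scale_factor
print(up.shape, up2.shape)                      # both torch.Size([2, 64, 74, 74])
```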
capvector-pi05/src/vggt/heads/head_act.py
ADDED
|
@@ -0,0 +1,125 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
|
| 13 |
+
"""
|
| 14 |
+
Activate pose parameters with specified activation functions.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
|
| 18 |
+
trans_act: Activation type for translation component
|
| 19 |
+
quat_act: Activation type for quaternion component
|
| 20 |
+
fl_act: Activation type for focal length component
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
Activated pose parameters tensor
|
| 24 |
+
"""
|
| 25 |
+
T = pred_pose_enc[..., :3]
|
| 26 |
+
quat = pred_pose_enc[..., 3:7]
|
| 27 |
+
fl = pred_pose_enc[..., 7:] # or fov
|
| 28 |
+
|
| 29 |
+
T = base_pose_act(T, trans_act)
|
| 30 |
+
quat = base_pose_act(quat, quat_act)
|
| 31 |
+
fl = base_pose_act(fl, fl_act) # or fov
|
| 32 |
+
|
| 33 |
+
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
|
| 34 |
+
|
| 35 |
+
return pred_pose_enc
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
| 39 |
+
"""
|
| 40 |
+
Apply basic activation function to pose parameters.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
pose_enc: Tensor containing encoded pose parameters
|
| 44 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Activated pose parameters
|
| 48 |
+
"""
|
| 49 |
+
if act_type == "linear":
|
| 50 |
+
return pose_enc
|
| 51 |
+
elif act_type == "inv_log":
|
| 52 |
+
return inverse_log_transform(pose_enc)
|
| 53 |
+
elif act_type == "exp":
|
| 54 |
+
return torch.exp(pose_enc)
|
| 55 |
+
elif act_type == "relu":
|
| 56 |
+
return F.relu(pose_enc)
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
| 62 |
+
"""
|
| 63 |
+
Process network output to extract 3D points and confidence values.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
out: Network output tensor (B, C, H, W)
|
| 67 |
+
activation: Activation type for 3D points
|
| 68 |
+
conf_activation: Activation type for confidence values
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 72 |
+
"""
|
| 73 |
+
# Move channels from dim 1 to the last dimension => (B, H, W, C)
|
| 74 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 75 |
+
|
| 76 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 77 |
+
xyz = fmap[:, :, :, :-1]
|
| 78 |
+
conf = fmap[:, :, :, -1]
|
| 79 |
+
|
| 80 |
+
if activation == "norm_exp":
|
| 81 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 82 |
+
xyz_normed = xyz / d
|
| 83 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 84 |
+
elif activation == "norm":
|
| 85 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 86 |
+
elif activation == "exp":
|
| 87 |
+
pts3d = torch.exp(xyz)
|
| 88 |
+
elif activation == "relu":
|
| 89 |
+
pts3d = F.relu(xyz)
|
| 90 |
+
elif activation == "inv_log":
|
| 91 |
+
pts3d = inverse_log_transform(xyz)
|
| 92 |
+
elif activation == "xy_inv_log":
|
| 93 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
| 94 |
+
z = inverse_log_transform(z)
|
| 95 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
| 96 |
+
elif activation == "sigmoid":
|
| 97 |
+
pts3d = torch.sigmoid(xyz)
|
| 98 |
+
elif activation == "linear":
|
| 99 |
+
pts3d = xyz
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 102 |
+
|
| 103 |
+
if conf_activation == "expp1":
|
| 104 |
+
conf_out = 1 + conf.exp()
|
| 105 |
+
elif conf_activation == "expp0":
|
| 106 |
+
conf_out = conf.exp()
|
| 107 |
+
elif conf_activation == "sigmoid":
|
| 108 |
+
conf_out = torch.sigmoid(conf)
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 111 |
+
|
| 112 |
+
return pts3d, conf_out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def inverse_log_transform(y):
|
| 116 |
+
"""
|
| 117 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
y: Input tensor
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Transformed tensor
|
| 124 |
+
"""
|
| 125 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
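A minimal sketch of how the activation helpers above could be exercised (not part of the added file; it assumes `vggt.heads.head_act` is importable from this repo's `src/` directory, and the shapes and activation choices are illustrative only).

```python
import torch
from vggt.heads.head_act import activate_pose, activate_head, inverse_log_transform

# Pose branch: 3 translation + 4 quaternion + 2 focal-length channels.
pose_enc = torch.randn(2, 9)
pose = activate_pose(pose_enc, trans_act="inv_log", quat_act="linear", fl_act="relu")
print(pose.shape)                          # torch.Size([2, 9])

# Dense branch: the last channel becomes confidence, the rest become 3D points.
out = torch.randn(2, 4, 8, 8)              # (B, C, H, W) network output
pts3d, conf = activate_head(out, activation="norm_exp", conf_activation="expp1")
print(pts3d.shape, conf.shape)             # torch.Size([2, 8, 8, 3]) torch.Size([2, 8, 8])

# inverse_log_transform undoes sign(x) * log1p(|x|):
x = torch.tensor([-3.0, 0.0, 5.0])
print(inverse_log_transform(torch.sign(x) * torch.log1p(x.abs())))  # ~[-3., 0., 5.]
```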
capvector-pi05/src/vggt/heads/track_head.py
ADDED
|
@@ -0,0 +1,104 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from .dpt_head import DPTHead
|
| 9 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TrackHead(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
|
| 15 |
+
The tracking is performed iteratively, refining predictions over multiple iterations.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
dim_in,
|
| 21 |
+
patch_size=14,
|
| 22 |
+
features=128,
|
| 23 |
+
iters=4,
|
| 24 |
+
predict_conf=True,
|
| 25 |
+
stride=2,
|
| 26 |
+
corr_levels=7,
|
| 27 |
+
corr_radius=4,
|
| 28 |
+
hidden_size=384,
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Initialize the TrackHead module.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
dim_in (int): Input dimension of tokens from the backbone.
|
| 35 |
+
patch_size (int): Size of image patches used in the vision transformer.
|
| 36 |
+
features (int): Number of feature channels in the feature extractor output.
|
| 37 |
+
iters (int): Number of refinement iterations for tracking predictions.
|
| 38 |
+
predict_conf (bool): Whether to predict confidence scores for tracked points.
|
| 39 |
+
stride (int): Stride value for the tracker predictor.
|
| 40 |
+
corr_levels (int): Number of correlation pyramid levels
|
| 41 |
+
corr_radius (int): Radius for correlation computation, controlling the search area.
|
| 42 |
+
hidden_size (int): Size of hidden layers in the tracker network.
|
| 43 |
+
"""
|
| 44 |
+
super().__init__()
|
| 45 |
+
|
| 46 |
+
self.patch_size = patch_size
|
| 47 |
+
|
| 48 |
+
# Feature extractor based on DPT architecture
|
| 49 |
+
# Processes tokens into feature maps for tracking
|
| 50 |
+
self.feature_extractor = DPTHead(
|
| 51 |
+
dim_in=dim_in,
|
| 52 |
+
patch_size=patch_size,
|
| 53 |
+
features=features,
|
| 54 |
+
feature_only=True, # Only output features, no activation
|
| 55 |
+
down_ratio=2, # Reduces spatial dimensions by factor of 2
|
| 56 |
+
pos_embed=False,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Tracker module that predicts point trajectories
|
| 60 |
+
# Takes feature maps and predicts coordinates and visibility
|
| 61 |
+
self.tracker = BaseTrackerPredictor(
|
| 62 |
+
latent_dim=features, # Match the output_dim of feature extractor
|
| 63 |
+
predict_conf=predict_conf,
|
| 64 |
+
stride=stride,
|
| 65 |
+
corr_levels=corr_levels,
|
| 66 |
+
corr_radius=corr_radius,
|
| 67 |
+
hidden_size=hidden_size,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.iters = iters
|
| 71 |
+
|
| 72 |
+
def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
|
| 73 |
+
"""
|
| 74 |
+
Forward pass of the TrackHead.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
aggregated_tokens_list (list): List of aggregated tokens from the backbone.
|
| 78 |
+
images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
|
| 79 |
+
B = batch size, S = sequence length.
|
| 80 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 81 |
+
query_points (torch.Tensor, optional): Initial query points to track.
|
| 82 |
+
If None, points are initialized by the tracker.
|
| 83 |
+
iters (int, optional): Number of refinement iterations. If None, uses self.iters.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
tuple:
|
| 87 |
+
- coord_preds (torch.Tensor): Predicted coordinates for tracked points.
|
| 88 |
+
- vis_scores (torch.Tensor): Visibility scores for tracked points.
|
| 89 |
+
- conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
|
| 90 |
+
"""
|
| 91 |
+
B, S, _, H, W = images.shape
|
| 92 |
+
|
| 93 |
+
# Extract features from tokens
|
| 94 |
+
# feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
|
| 95 |
+
feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
|
| 96 |
+
|
| 97 |
+
# Use default iterations if not specified
|
| 98 |
+
if iters is None:
|
| 99 |
+
iters = self.iters
|
| 100 |
+
|
| 101 |
+
# Perform tracking using the extracted features
|
| 102 |
+
coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters)
|
| 103 |
+
|
| 104 |
+
return coord_preds, vis_scores, conf_scores
|
capvector-pi05/src/vggt/heads/track_modules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py
ADDED
|
@@ -0,0 +1,209 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from .blocks import EfficientUpdateFormer, CorrBlock
|
| 13 |
+
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
|
| 14 |
+
from .modules import Mlp
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseTrackerPredictor(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
stride=1,
|
| 21 |
+
corr_levels=5,
|
| 22 |
+
corr_radius=4,
|
| 23 |
+
latent_dim=128,
|
| 24 |
+
hidden_size=384,
|
| 25 |
+
use_spaceatt=True,
|
| 26 |
+
depth=6,
|
| 27 |
+
max_scale=518,
|
| 28 |
+
predict_conf=True,
|
| 29 |
+
):
|
| 30 |
+
super(BaseTrackerPredictor, self).__init__()
|
| 31 |
+
"""
|
| 32 |
+
The base template to create a track predictor
|
| 33 |
+
|
| 34 |
+
Modified from https://github.com/facebookresearch/co-tracker/
|
| 35 |
+
and https://github.com/facebookresearch/vggsfm
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
self.stride = stride
|
| 39 |
+
self.latent_dim = latent_dim
|
| 40 |
+
self.corr_levels = corr_levels
|
| 41 |
+
self.corr_radius = corr_radius
|
| 42 |
+
self.hidden_size = hidden_size
|
| 43 |
+
self.max_scale = max_scale
|
| 44 |
+
self.predict_conf = predict_conf
|
| 45 |
+
|
| 46 |
+
self.flows_emb_dim = latent_dim // 2
|
| 47 |
+
|
| 48 |
+
self.corr_mlp = Mlp(
|
| 49 |
+
in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
|
| 50 |
+
hidden_features=self.hidden_size,
|
| 51 |
+
out_features=self.latent_dim,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
|
| 55 |
+
|
| 56 |
+
self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
|
| 57 |
+
|
| 58 |
+
space_depth = depth if use_spaceatt else 0
|
| 59 |
+
time_depth = depth
|
| 60 |
+
|
| 61 |
+
self.updateformer = EfficientUpdateFormer(
|
| 62 |
+
space_depth=space_depth,
|
| 63 |
+
time_depth=time_depth,
|
| 64 |
+
input_dim=self.transformer_dim,
|
| 65 |
+
hidden_size=self.hidden_size,
|
| 66 |
+
output_dim=self.latent_dim + 2,
|
| 67 |
+
mlp_ratio=4.0,
|
| 68 |
+
add_space_attn=use_spaceatt,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.fmap_norm = nn.LayerNorm(self.latent_dim)
|
| 72 |
+
self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
|
| 73 |
+
|
| 74 |
+
# A linear layer to update track feats at each iteration
|
| 75 |
+
self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
|
| 76 |
+
|
| 77 |
+
self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
| 78 |
+
|
| 79 |
+
if predict_conf:
|
| 80 |
+
self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
| 81 |
+
|
| 82 |
+
def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
|
| 83 |
+
"""
|
| 84 |
+
query_points: B x N x 2, the number of batches, tracks, and xy
|
| 85 |
+
fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
|
| 86 |
+
note that HH and WW are the spatial size of the feature maps, not of the original images
|
| 87 |
+
"""
|
| 88 |
+
B, N, D = query_points.shape
|
| 89 |
+
B, S, C, HH, WW = fmaps.shape
|
| 90 |
+
|
| 91 |
+
assert D == 2, "Input points must be 2D coordinates"
|
| 92 |
+
|
| 93 |
+
# apply a layernorm to fmaps here
|
| 94 |
+
fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
|
| 95 |
+
fmaps = fmaps.permute(0, 1, 4, 2, 3)
|
| 96 |
+
|
| 97 |
+
# Scale the input query_points because we may downsample the images
|
| 98 |
+
# by down_ratio or self.stride
|
| 99 |
+
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
|
| 100 |
+
# its query_points should be query_points/4
|
| 101 |
+
if down_ratio > 1:
|
| 102 |
+
query_points = query_points / float(down_ratio)
|
| 103 |
+
|
| 104 |
+
query_points = query_points / float(self.stride)
|
| 105 |
+
|
| 106 |
+
# Init with coords as the query points
|
| 107 |
+
# It means the search will start from the position of query points at the reference frames
|
| 108 |
+
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
|
| 109 |
+
|
| 110 |
+
# Sample/extract the features of the query points in the query frame
|
| 111 |
+
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
|
| 112 |
+
|
| 113 |
+
# init track feats by query feats
|
| 114 |
+
track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
|
| 115 |
+
# back up the init coords
|
| 116 |
+
coords_backup = coords.clone()
|
| 117 |
+
|
| 118 |
+
fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
|
| 119 |
+
|
| 120 |
+
coord_preds = []
|
| 121 |
+
|
| 122 |
+
# Iterative Refinement
|
| 123 |
+
for _ in range(iters):
|
| 124 |
+
# Detach the gradients from the last iteration
|
| 125 |
+
# (in my experience, not very important for performance)
|
| 126 |
+
coords = coords.detach()
|
| 127 |
+
|
| 128 |
+
fcorrs = fcorr_fn.corr_sample(track_feats, coords)
|
| 129 |
+
|
| 130 |
+
corr_dim = fcorrs.shape[3]
|
| 131 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
|
| 132 |
+
fcorrs_ = self.corr_mlp(fcorrs_)
|
| 133 |
+
|
| 134 |
+
# Movement of current coords relative to query points
|
| 135 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
| 136 |
+
|
| 137 |
+
flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
|
| 138 |
+
|
| 139 |
+
# (In my trials, it is also okay to just add the flows_emb instead of concat)
|
| 140 |
+
flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
|
| 141 |
+
|
| 142 |
+
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
| 143 |
+
|
| 144 |
+
# Concatenate them as the input for the transformers
|
| 145 |
+
transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
|
| 146 |
+
|
| 147 |
+
# 2D positional embed
|
| 148 |
+
# TODO: this can be much simplified
|
| 149 |
+
pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
|
| 150 |
+
sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
|
| 151 |
+
|
| 152 |
+
sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
|
| 153 |
+
|
| 154 |
+
x = transformer_input + sampled_pos_emb
|
| 155 |
+
|
| 156 |
+
# Add the query ref token to the track feats
|
| 157 |
+
query_ref_token = torch.cat(
|
| 158 |
+
[self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
|
| 159 |
+
)
|
| 160 |
+
x = x + query_ref_token.to(x.device).to(x.dtype)
|
| 161 |
+
|
| 162 |
+
# B, N, S, C
|
| 163 |
+
x = rearrange(x, "(b n) s d -> b n s d", b=B)
|
| 164 |
+
|
| 165 |
+
# Compute the delta coordinates and delta track features
|
| 166 |
+
delta, _ = self.updateformer(x)
|
| 167 |
+
|
| 168 |
+
# BN, S, C
|
| 169 |
+
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
|
| 170 |
+
delta_coords_ = delta[:, :, :2]
|
| 171 |
+
delta_feats_ = delta[:, :, 2:]
|
| 172 |
+
|
| 173 |
+
track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
|
| 174 |
+
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
| 175 |
+
|
| 176 |
+
# Update the track features
|
| 177 |
+
track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
|
| 178 |
+
|
| 179 |
+
track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
|
| 180 |
+
|
| 181 |
+
# B x S x N x 2
|
| 182 |
+
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
| 183 |
+
|
| 184 |
+
# Force coord0 as query
|
| 185 |
+
# because we assume the query points should not be changed
|
| 186 |
+
coords[:, 0] = coords_backup[:, 0]
|
| 187 |
+
|
| 188 |
+
# The predicted tracks are in the original image scale
|
| 189 |
+
if down_ratio > 1:
|
| 190 |
+
coord_preds.append(coords * self.stride * down_ratio)
|
| 191 |
+
else:
|
| 192 |
+
coord_preds.append(coords * self.stride)
|
| 193 |
+
|
| 194 |
+
# B, S, N
|
| 195 |
+
vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
| 196 |
+
if apply_sigmoid:
|
| 197 |
+
vis_e = torch.sigmoid(vis_e)
|
| 198 |
+
|
| 199 |
+
if self.predict_conf:
|
| 200 |
+
conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
| 201 |
+
if apply_sigmoid:
|
| 202 |
+
conf_e = torch.sigmoid(conf_e)
|
| 203 |
+
else:
|
| 204 |
+
conf_e = None
|
| 205 |
+
|
| 206 |
+
if return_feat:
|
| 207 |
+
return coord_preds, vis_e, track_feats, query_track_feat, conf_e
|
| 208 |
+
else:
|
| 209 |
+
return coord_preds, vis_e, conf_e
|
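A small sanity check of the token-width arithmetic in `BaseTrackerPredictor` above (not part of the added file; it assumes `vggt.heads.track_modules.utils` is importable and uses the default `latent_dim=128`).

```python
import torch
from vggt.heads.track_modules.utils import get_2d_embedding  # assumes src/ is importable

latent_dim = 128
flows = torch.randn(6, 4, 2)   # (B*N, S, 2) dummy flow offsets
emb = get_2d_embedding(flows, latent_dim // 2, cat_coords=False)
print(emb.shape)               # torch.Size([6, 4, 128]): sin/cos over x and y
# Concatenating two scaled copies of the raw flow (2 + 2 channels), the correlation
# features (128) and the track features (128) gives 132 + 128 + 128 = 388, which
# matches transformer_dim = 3 * latent_dim + 4 used to size the update transformer.
```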
capvector-pi05/src/vggt/heads/track_modules/blocks.py
ADDED
|
@@ -0,0 +1,236 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Modified from https://github.com/facebookresearch/co-tracker/
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from .utils import bilinear_sampler
|
| 16 |
+
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EfficientUpdateFormer(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Transformer model that updates track estimates.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
space_depth=6,
|
| 27 |
+
time_depth=6,
|
| 28 |
+
input_dim=320,
|
| 29 |
+
hidden_size=384,
|
| 30 |
+
num_heads=8,
|
| 31 |
+
output_dim=130,
|
| 32 |
+
mlp_ratio=4.0,
|
| 33 |
+
add_space_attn=True,
|
| 34 |
+
num_virtual_tracks=64,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
self.out_channels = 2
|
| 39 |
+
self.num_heads = num_heads
|
| 40 |
+
self.hidden_size = hidden_size
|
| 41 |
+
self.add_space_attn = add_space_attn
|
| 42 |
+
|
| 43 |
+
# Add input LayerNorm before linear projection
|
| 44 |
+
self.input_norm = nn.LayerNorm(input_dim)
|
| 45 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 46 |
+
|
| 47 |
+
# Add output LayerNorm before final projection
|
| 48 |
+
self.output_norm = nn.LayerNorm(hidden_size)
|
| 49 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
| 50 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 51 |
+
|
| 52 |
+
if self.add_space_attn:
|
| 53 |
+
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
|
| 54 |
+
else:
|
| 55 |
+
self.virual_tracks = None
|
| 56 |
+
|
| 57 |
+
self.time_blocks = nn.ModuleList(
|
| 58 |
+
[
|
| 59 |
+
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
|
| 60 |
+
for _ in range(time_depth)
|
| 61 |
+
]
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
if add_space_attn:
|
| 65 |
+
self.space_virtual_blocks = nn.ModuleList(
|
| 66 |
+
[
|
| 67 |
+
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
|
| 68 |
+
for _ in range(space_depth)
|
| 69 |
+
]
|
| 70 |
+
)
|
| 71 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
| 72 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
| 73 |
+
)
|
| 74 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
| 75 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
| 76 |
+
)
|
| 77 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
| 78 |
+
self.initialize_weights()
|
| 79 |
+
|
| 80 |
+
def initialize_weights(self):
|
| 81 |
+
def _basic_init(module):
|
| 82 |
+
if isinstance(module, nn.Linear):
|
| 83 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 84 |
+
if module.bias is not None:
|
| 85 |
+
nn.init.constant_(module.bias, 0)
|
| 86 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
| 87 |
+
|
| 88 |
+
self.apply(_basic_init)
|
| 89 |
+
|
| 90 |
+
def forward(self, input_tensor, mask=None):
|
| 91 |
+
# Apply input LayerNorm
|
| 92 |
+
input_tensor = self.input_norm(input_tensor)
|
| 93 |
+
tokens = self.input_transform(input_tensor)
|
| 94 |
+
|
| 95 |
+
init_tokens = tokens
|
| 96 |
+
|
| 97 |
+
B, _, T, _ = tokens.shape
|
| 98 |
+
|
| 99 |
+
if self.add_space_attn:
|
| 100 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
| 101 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 102 |
+
|
| 103 |
+
_, N, _, _ = tokens.shape
|
| 104 |
+
|
| 105 |
+
j = 0
|
| 106 |
+
for i in range(len(self.time_blocks)):
|
| 107 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 108 |
+
|
| 109 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
| 110 |
+
|
| 111 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 112 |
+
if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
|
| 113 |
+
space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
|
| 114 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
| 115 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
| 116 |
+
|
| 117 |
+
virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
|
| 118 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
| 119 |
+
point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
|
| 120 |
+
|
| 121 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 122 |
+
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
|
| 123 |
+
j += 1
|
| 124 |
+
|
| 125 |
+
if self.add_space_attn:
|
| 126 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
| 127 |
+
|
| 128 |
+
tokens = tokens + init_tokens
|
| 129 |
+
|
| 130 |
+
# Apply output LayerNorm before final projection
|
| 131 |
+
tokens = self.output_norm(tokens)
|
| 132 |
+
flow = self.flow_head(tokens)
|
| 133 |
+
|
| 134 |
+
return flow, None
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class CorrBlock:
|
| 138 |
+
def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
|
| 139 |
+
"""
|
| 140 |
+
Build a pyramid of feature maps from the input.
|
| 141 |
+
|
| 142 |
+
fmaps: Tensor (B, S, C, H, W)
|
| 143 |
+
num_levels: number of pyramid levels (each downsampled by factor 2)
|
| 144 |
+
radius: search radius for sampling correlation
|
| 145 |
+
multiple_track_feats: if True, split the target features per pyramid level
|
| 146 |
+
padding_mode: passed to grid_sample / bilinear_sampler
|
| 147 |
+
"""
|
| 148 |
+
B, S, C, H, W = fmaps.shape
|
| 149 |
+
self.S, self.C, self.H, self.W = S, C, H, W
|
| 150 |
+
self.num_levels = num_levels
|
| 151 |
+
self.radius = radius
|
| 152 |
+
self.padding_mode = padding_mode
|
| 153 |
+
self.multiple_track_feats = multiple_track_feats
|
| 154 |
+
|
| 155 |
+
# Build pyramid: each level is half the spatial resolution of the previous
|
| 156 |
+
self.fmaps_pyramid = [fmaps] # level 0 is full resolution
|
| 157 |
+
current_fmaps = fmaps
|
| 158 |
+
for i in range(num_levels - 1):
|
| 159 |
+
B, S, C, H, W = current_fmaps.shape
|
| 160 |
+
# Merge batch & sequence dimensions
|
| 161 |
+
current_fmaps = current_fmaps.reshape(B * S, C, H, W)
|
| 162 |
+
# Avg pool down by factor 2
|
| 163 |
+
current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
|
| 164 |
+
_, _, H_new, W_new = current_fmaps.shape
|
| 165 |
+
current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
|
| 166 |
+
self.fmaps_pyramid.append(current_fmaps)
|
| 167 |
+
|
| 168 |
+
# Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
|
| 169 |
+
# This grid is added to the (scaled) coordinate centroids.
|
| 170 |
+
r = self.radius
|
| 171 |
+
dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
|
| 172 |
+
dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
|
| 173 |
+
# delta: for every (dy,dx) displacement (i.e. Δx, Δy)
|
| 174 |
+
self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
|
| 175 |
+
|
| 176 |
+
def corr_sample(self, targets, coords):
|
| 177 |
+
"""
|
| 178 |
+
Instead of storing the entire correlation pyramid, we compute each level's correlation
|
| 179 |
+
volume, sample it immediately, then discard it. This saves GPU memory.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
targets: Tensor (B, S, N, C) — features for the current targets.
|
| 183 |
+
coords: Tensor (B, S, N, 2) — coordinates at full resolution.
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
|
| 187 |
+
"""
|
| 188 |
+
B, S, N, C = targets.shape
|
| 189 |
+
|
| 190 |
+
# If you have multiple track features, split them per level.
|
| 191 |
+
if self.multiple_track_feats:
|
| 192 |
+
targets_split = torch.split(targets, C // self.num_levels, dim=-1)
|
| 193 |
+
|
| 194 |
+
out_pyramid = []
|
| 195 |
+
for i, fmaps in enumerate(self.fmaps_pyramid):
|
| 196 |
+
# Get current spatial resolution H, W for this pyramid level.
|
| 197 |
+
B, S, C, H, W = fmaps.shape
|
| 198 |
+
# Reshape feature maps for correlation computation:
|
| 199 |
+
# fmap2s: (B, S, C, H*W)
|
| 200 |
+
fmap2s = fmaps.view(B, S, C, H * W)
|
| 201 |
+
# Choose appropriate target features.
|
| 202 |
+
fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
|
| 203 |
+
|
| 204 |
+
# Compute correlation directly
|
| 205 |
+
corrs = compute_corr_level(fmap1, fmap2s, C)
|
| 206 |
+
corrs = corrs.view(B, S, N, H, W)
|
| 207 |
+
|
| 208 |
+
# Prepare sampling grid:
|
| 209 |
+
# Scale down the coordinates for the current level.
|
| 210 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
|
| 211 |
+
# Make sure our precomputed delta grid is on the same device/dtype.
|
| 212 |
+
delta_lvl = self.delta.to(coords.device).to(coords.dtype)
|
| 213 |
+
# Now the grid for grid_sample is:
|
| 214 |
+
# coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
|
| 215 |
+
coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
|
| 216 |
+
|
| 217 |
+
# Sample from the correlation volume using bilinear interpolation.
|
| 218 |
+
# We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
|
| 219 |
+
corrs_sampled = bilinear_sampler(
|
| 220 |
+
corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
|
| 221 |
+
)
|
| 222 |
+
# The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
|
| 223 |
+
corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
|
| 224 |
+
out_pyramid.append(corrs_sampled)
|
| 225 |
+
|
| 226 |
+
# Concatenate all levels along the last dimension.
|
| 227 |
+
out = torch.cat(out_pyramid, dim=-1).contiguous()
|
| 228 |
+
return out
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def compute_corr_level(fmap1, fmap2s, C):
|
| 232 |
+
# fmap1: (B, S, N, C)
|
| 233 |
+
# fmap2s: (B, S, C, H*W)
|
| 234 |
+
corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
|
| 235 |
+
corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
|
| 236 |
+
return corrs / math.sqrt(C)
|
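A minimal sketch of `CorrBlock` in isolation (not part of the added file; shapes are illustrative and it assumes the `vggt` package is importable). The sampled correlation width per track and frame is `num_levels * (2*radius + 1)**2`.

```python
import torch
from vggt.heads.track_modules.blocks import CorrBlock

B, S, C, H, W, N = 1, 2, 32, 16, 16, 5
fmaps = torch.randn(B, S, C, H, W)         # per-frame feature maps
targets = torch.randn(B, S, N, C)          # per-track feature vectors
coords = torch.rand(B, S, N, 2) * (H - 1)  # track coordinates on the feature map
corr = CorrBlock(fmaps, num_levels=3, radius=1)
out = corr.corr_sample(targets, coords)
print(out.shape)  # torch.Size([1, 2, 5, 27]) == (B, S, N, num_levels * (2*radius+1)**2)
```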
capvector-pi05/src/vggt/heads/track_modules/modules.py
ADDED
|
@@ -0,0 +1,204 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from functools import partial
|
| 12 |
+
from typing import Callable
|
| 13 |
+
import collections
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from itertools import repeat
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# From PyTorch internals
|
| 19 |
+
def _ntuple(n):
|
| 20 |
+
def parse(x):
|
| 21 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 22 |
+
return tuple(x)
|
| 23 |
+
return tuple(repeat(x, n))
|
| 24 |
+
|
| 25 |
+
return parse
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def exists(val):
|
| 29 |
+
return val is not None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def default(val, d):
|
| 33 |
+
return val if exists(val) else d
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
to_2tuple = _ntuple(2)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ResidualBlock(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
ResidualBlock: construct a block of two conv layers with residual connections
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
|
| 45 |
+
super(ResidualBlock, self).__init__()
|
| 46 |
+
|
| 47 |
+
self.conv1 = nn.Conv2d(
|
| 48 |
+
in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
|
| 49 |
+
)
|
| 50 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
|
| 51 |
+
self.relu = nn.ReLU(inplace=True)
|
| 52 |
+
|
| 53 |
+
num_groups = planes // 8
|
| 54 |
+
|
| 55 |
+
if norm_fn == "group":
|
| 56 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 57 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 58 |
+
if not stride == 1:
|
| 59 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 60 |
+
|
| 61 |
+
elif norm_fn == "batch":
|
| 62 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
| 63 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
| 64 |
+
if not stride == 1:
|
| 65 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
| 66 |
+
|
| 67 |
+
elif norm_fn == "instance":
|
| 68 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
| 69 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
| 70 |
+
if not stride == 1:
|
| 71 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
| 72 |
+
|
| 73 |
+
elif norm_fn == "none":
|
| 74 |
+
self.norm1 = nn.Sequential()
|
| 75 |
+
self.norm2 = nn.Sequential()
|
| 76 |
+
if not stride == 1:
|
| 77 |
+
self.norm3 = nn.Sequential()
|
| 78 |
+
else:
|
| 79 |
+
raise NotImplementedError
|
| 80 |
+
|
| 81 |
+
if stride == 1:
|
| 82 |
+
self.downsample = None
|
| 83 |
+
else:
|
| 84 |
+
self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
| 85 |
+
|
| 86 |
+
def forward(self, x):
|
| 87 |
+
y = x
|
| 88 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 89 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 90 |
+
|
| 91 |
+
if self.downsample is not None:
|
| 92 |
+
x = self.downsample(x)
|
| 93 |
+
|
| 94 |
+
return self.relu(x + y)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Mlp(nn.Module):
|
| 98 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 99 |
+
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
in_features,
|
| 103 |
+
hidden_features=None,
|
| 104 |
+
out_features=None,
|
| 105 |
+
act_layer=nn.GELU,
|
| 106 |
+
norm_layer=None,
|
| 107 |
+
bias=True,
|
| 108 |
+
drop=0.0,
|
| 109 |
+
use_conv=False,
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
out_features = out_features or in_features
|
| 113 |
+
hidden_features = hidden_features or in_features
|
| 114 |
+
bias = to_2tuple(bias)
|
| 115 |
+
drop_probs = to_2tuple(drop)
|
| 116 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 117 |
+
|
| 118 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 119 |
+
self.act = act_layer()
|
| 120 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 121 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 122 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 123 |
+
|
| 124 |
+
def forward(self, x):
|
| 125 |
+
x = self.fc1(x)
|
| 126 |
+
x = self.act(x)
|
| 127 |
+
x = self.drop1(x)
|
| 128 |
+
x = self.fc2(x)
|
| 129 |
+
x = self.drop2(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class AttnBlock(nn.Module):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
hidden_size,
|
| 137 |
+
num_heads,
|
| 138 |
+
attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
|
| 139 |
+
mlp_ratio=4.0,
|
| 140 |
+
**block_kwargs,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
Self attention block
|
| 144 |
+
"""
|
| 145 |
+
super().__init__()
|
| 146 |
+
|
| 147 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
| 148 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
| 149 |
+
|
| 150 |
+
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
|
| 151 |
+
|
| 152 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 153 |
+
|
| 154 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
| 155 |
+
|
| 156 |
+
def forward(self, x, mask=None):
|
| 157 |
+
# Prepare the mask for PyTorch's attention (it expects a different format)
|
| 158 |
+
# attn_mask = mask if mask is not None else None
|
| 159 |
+
# Normalize before attention
|
| 160 |
+
x = self.norm1(x)
|
| 161 |
+
|
| 162 |
+
# PyTorch's MultiheadAttention returns attn_output, attn_output_weights
|
| 163 |
+
# attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
|
| 164 |
+
|
| 165 |
+
attn_output, _ = self.attn(x, x, x)
|
| 166 |
+
|
| 167 |
+
# Add & Norm
|
| 168 |
+
x = x + attn_output
|
| 169 |
+
x = x + self.mlp(self.norm2(x))
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class CrossAttnBlock(nn.Module):
|
| 174 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
| 175 |
+
"""
|
| 176 |
+
Cross attention block
|
| 177 |
+
"""
|
| 178 |
+
super().__init__()
|
| 179 |
+
|
| 180 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
| 181 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
| 182 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
| 183 |
+
|
| 184 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 185 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 189 |
+
|
| 190 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
| 191 |
+
|
| 192 |
+
def forward(self, x, context, mask=None):
|
| 193 |
+
# Normalize inputs
|
| 194 |
+
x = self.norm1(x)
|
| 195 |
+
context = self.norm_context(context)
|
| 196 |
+
|
| 197 |
+
# Apply cross attention
|
| 198 |
+
# Note: nn.MultiheadAttention returns attn_output, attn_output_weights
|
| 199 |
+
attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
|
| 200 |
+
|
| 201 |
+
# Add & Norm
|
| 202 |
+
x = x + attn_output
|
| 203 |
+
x = x + self.mlp(self.norm2(x))
|
| 204 |
+
return x
|
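A minimal sketch exercising the attention blocks above with toy tensors (not part of the added file; hidden sizes are arbitrary and it assumes `vggt.heads.track_modules.modules` is importable).

```python
import torch
from vggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

x = torch.randn(2, 10, 64)    # (batch, tokens, hidden)
ctx = torch.randn(2, 7, 64)   # context tokens for cross attention
self_block = AttnBlock(hidden_size=64, num_heads=4)
cross_block = CrossAttnBlock(hidden_size=64, context_dim=64, num_heads=4)
y = self_block(x)             # pre-norm self-attention + MLP with residual adds
z = cross_block(x, ctx)       # queries x against ctx keys/values
print(y.shape, z.shape)       # torch.Size([2, 10, 64]) torch.Size([2, 10, 64])
```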
capvector-pi05/src/vggt/heads/track_modules/utils.py
ADDED
|
@@ -0,0 +1,223 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from https://github.com/facebookresearch/vggsfm
|
| 8 |
+
# and https://github.com/facebookresearch/co-tracker/tree/main
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from typing import Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
|
| 19 |
+
"""
|
| 20 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
| 21 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
| 22 |
+
Args:
|
| 23 |
+
- embed_dim: The embedding dimension.
|
| 24 |
+
- grid_size: The grid size.
|
| 25 |
+
Returns:
|
| 26 |
+
- pos_embed: The generated 2D positional embedding.
|
| 27 |
+
"""
|
| 28 |
+
if isinstance(grid_size, tuple):
|
| 29 |
+
grid_size_h, grid_size_w = grid_size
|
| 30 |
+
else:
|
| 31 |
+
grid_size_h = grid_size_w = grid_size
|
| 32 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
| 33 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
| 34 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
| 35 |
+
grid = torch.stack(grid, dim=0)
|
| 36 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
| 37 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 38 |
+
if return_grid:
|
| 39 |
+
return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
|
| 40 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
|
| 44 |
+
"""
|
| 45 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
- embed_dim: The embedding dimension.
|
| 49 |
+
- grid: The grid to generate the embedding from.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
- emb: The generated 2D positional embedding.
|
| 53 |
+
"""
|
| 54 |
+
assert embed_dim % 2 == 0
|
| 55 |
+
|
| 56 |
+
# use half of dimensions to encode grid_h
|
| 57 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 58 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 59 |
+
|
| 60 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
| 61 |
+
return emb
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
|
| 65 |
+
"""
|
| 66 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
- embed_dim: The embedding dimension.
|
| 70 |
+
- pos: The position to generate the embedding from.
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
- emb: The generated 1D positional embedding.
|
| 74 |
+
"""
|
| 75 |
+
assert embed_dim % 2 == 0
|
| 76 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
| 77 |
+
omega /= embed_dim / 2.0
|
| 78 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 79 |
+
|
| 80 |
+
pos = pos.reshape(-1) # (M,)
|
| 81 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 82 |
+
|
| 83 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 84 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 85 |
+
|
| 86 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 87 |
+
return emb[None].float()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
| 91 |
+
"""
|
| 92 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
- xy: The coordinates to generate the embedding from.
|
| 96 |
+
- C: The size of the embedding.
|
| 97 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
- pe: The generated 2D positional embedding.
|
| 101 |
+
"""
|
| 102 |
+
B, N, D = xy.shape
|
| 103 |
+
assert D == 2
|
| 104 |
+
|
| 105 |
+
x = xy[:, :, 0:1]
|
| 106 |
+
y = xy[:, :, 1:2]
|
| 107 |
+
div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
|
| 108 |
+
|
| 109 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 110 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 111 |
+
|
| 112 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
| 113 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
| 114 |
+
|
| 115 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
| 116 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
| 117 |
+
|
| 118 |
+
pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, 2*C)
|
| 119 |
+
if cat_coords:
|
| 120 |
+
pe = torch.cat([xy, pe], dim=2)  # (B, N, 2*C + 2)
|
| 121 |
+
return pe
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
| 125 |
+
r"""Sample a tensor using bilinear interpolation
|
| 126 |
+
|
| 127 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
| 128 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
| 129 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
| 130 |
+
convention.
|
| 131 |
+
|
| 132 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
| 133 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
| 134 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
| 135 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
| 136 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
| 137 |
+
|
| 138 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
| 139 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
| 140 |
+
that in this case the order of the components is slightly different
|
| 141 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
| 142 |
+
|
| 143 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
| 144 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
| 145 |
+
left-most image pixel and :math:`W-1` to the center of the right-most
|
| 146 |
+
pixel.
|
| 147 |
+
|
| 148 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
| 149 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
| 150 |
+
the left-most pixel and :math:`W` to the right edge of the right-most
|
| 151 |
+
pixel.
|
| 152 |
+
|
| 153 |
+
Similar conventions apply to the :math:`y` for the range
|
| 154 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
| 155 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
input (Tensor): batch of input images.
|
| 159 |
+
coords (Tensor): batch of coordinates.
|
| 160 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
| 161 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
Tensor: sampled points.
|
| 165 |
+
"""
    coords = coords.detach().clone()
    ############################################################
    # IMPORTANT: grid_sample expects coords on the same device and dtype as the input.
    coords = coords.to(input.device).to(input.dtype)
    ############################################################

    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    if align_corners:
        scale = torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
        )
    else:
        scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)

    coords.mul_(scale)  # coords = coords * scale
    coords.sub_(1)  # coords = coords - 1

    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)


def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """

    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3])  # B C R 1 -> B R C
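A minimal usage sketch for the two samplers above (hypothetical shapes and values; assumes only the functions defined in this file):

import torch

# (B, C, H, W) feature map and (B, R, 2) query points in (x, y) pixel coordinates.
feat_map = torch.randn(2, 64, 32, 32)
points = torch.tensor([[[4.0, 7.0], [10.0, 3.0]],
                       [[0.0, 0.0], [31.0, 31.0]]])

point_feats = sample_features4d(feat_map, points)
print(point_feats.shape)  # torch.Size([2, 2, 64]) -- one C-dim feature per query point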
capvector-pi05/src/vggt/heads/utils.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Union
from einops import rearrange


def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)

    # Process x and y coordinates separately
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # [H*W, D/2]
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # [H*W, D/2]

    # Combine and reshape
    emb = torch.cat([emb_x, emb_y], dim=-1)  # [H*W, D]

    return emb.view(H, W, embed_dim)  # [H, W, D]


def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
        - embed_dim: The embedding dimension.
        - pos: The position to generate the embedding from.

    Returns:
        - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    device = pos.device
    omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.float()
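A small shape check for the two helpers above (hypothetical sizes; assumes the functions exactly as defined here):

import torch

# Build a 4x6 grid of (x, y) positions and embed it into 32 channels.
ys, xs = torch.meshgrid(torch.arange(4.0), torch.arange(6.0), indexing="ij")
grid = torch.stack([xs, ys], dim=-1)           # (H, W, 2) = (4, 6, 2)
emb = position_grid_to_embed(grid, embed_dim=32)
print(emb.shape)                               # torch.Size([4, 6, 32])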
# Inspired by https://github.com/microsoft/moge


def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (width, height, 2).

    The grid spans horizontally and vertically according to an aspect ratio,
    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
    corner is at (x_span, y_span), normalized by the diagonal of the plane.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (width, height, 2) tensor of UV coordinates.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans for X and Y
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Establish the linspace boundaries
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # Create 2D meshgrid (width x height) and stack into UV
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid


def _interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if input_elements > INT_MAX:
        chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
        interpolated_chunks = [
            nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
        ]
        x = torch.cat(interpolated_chunks, dim=0)
        return x.contiguous()
    else:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)


def _apply_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
    """
    Apply positional embedding to tensor x.
    """
    patch_w = x.shape[-1]
    patch_h = x.shape[-2]
    pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
    pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
    pos_embed = pos_embed * ratio
    pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
    return x + pos_embed


def interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe):
    (patch_h, patch_w) = patch_hw
    (img_h, img_w) = img_hw
    bs, N, S, D = hidden.shape
    re_sample_ratio = 1 / np.sqrt(N * S / reference.shape[1])

    _hidden = hidden.permute(0, 1, 3, 2)
    _hidden = _hidden.reshape(bs * N, D, patch_h, patch_w)
    if use_vggt_pe:
        _hidden = _apply_pos_embed(_hidden, img_w, img_h)
    hidden_pooling = _interpolate(
        _hidden, scale_factor=re_sample_ratio, mode=pooling_func, align_corners=True
    )
    hidden_pooling = hidden_pooling.reshape(bs, N, D, -1).permute(0, 1, 3, 2).reshape(bs, -1, D)
    return hidden_pooling


def custom_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe):
    if pooling_func in ["bilinear"]:
        return interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe)
    else:
        raise NotImplementedError(f"Pooling function {pooling_func} is not implemented.")
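A shape-level sketch of how `_apply_pos_embed` is meant to be used (hypothetical sizes, assuming the helpers above):

import torch

feats = torch.randn(2, 128, 24, 32)          # (B, C, patch_h, patch_w) feature map
out = _apply_pos_embed(feats, W=512, H=384)  # W, H: full image size, used only for the aspect ratio
assert out.shape == feats.shape              # the embedding is additive, so the shape is preserved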
capvector-pi05/src/vggt/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
capvector-pi05/src/vggt/layers/attention.py
ADDED
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch.nn.functional as F

XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            # xFormers is not imported in this module (XFORMERS_AVAILABLE is hard-coded
            # to False), so this fallback to the plain Attention forward is always taken.
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        # The path below is only reachable when xFormers is available; it relies on
        # xformers.ops.memory_efficient_attention and unbind, which are not imported here.
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
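A quick sanity check for the Attention module above (hypothetical sizes; assumes PyTorch 2.x so that `F.scaled_dot_product_attention` is available for the fused path):

import torch

attn = Attention(dim=384, num_heads=6)        # 6 heads of width 64
tokens = torch.randn(2, 197, 384)             # (B, N, C), e.g. 196 patch tokens + 1 cls token
out = attn(tokens)
assert out.shape == tokens.shape              # self-attention is shape-preserving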
capvector-pi05/src/vggt/layers/block.py
ADDED
@@ -0,0 +1,247 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention  # MemEffAttention is referenced by NestedTensorBlock below
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


XFORMERS_AVAILABLE = False


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
            x = drop_add_residual_stochastic_depth(
                x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
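A minimal forward-pass sketch for the Block above (hypothetical sizes; in eval mode both stochastic-depth branches are bypassed):

import torch

block = Block(dim=384, num_heads=6, mlp_ratio=4.0, drop_path=0.1)
block.eval()                                   # eval mode: plain residual path, no sample dropping
tokens = torch.randn(2, 196, 384)
with torch.no_grad():
    out = block(tokens)
assert out.shape == tokens.shape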
def drop_add_residual_stochastic_depth(
    x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        # if necessary, apply rope to the subset
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        # NOTE: scaled_index_add comes from xFormers; this branch is only used on the
        # nested-tensor path, which is disabled here (XFORMERS_AVAILABLE is False).
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # NOTE: fmha and index_select_cat come from xFormers; this helper is only reached
    # from the nested-tensor path, which is disabled here (XFORMERS_AVAILABLE is False).
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
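The rescaling in `drop_add_residual_stochastic_depth` keeps the expected residual magnitude constant: only a random subset of the batch receives the residual, and the survivors are scaled up by the inverse keep ratio. A small worked example of that arithmetic (illustrative numbers only):

b, sample_drop_ratio = 8, 0.25
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)   # 6 of 8 samples keep the residual
residual_scale_factor = b / sample_subset_size                  # 8 / 6 = 1.33..., applied to the survivors
print(sample_subset_size, residual_scale_factor)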
capvector-pi05/src/vggt/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
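A small behavioural sketch for `drop_path` above: during training each sample is either zeroed or rescaled by 1 / keep_prob, and outside training the input passes through unchanged.

import torch

x = torch.ones(4, 3, 8)                      # (B, ...) -- dropping happens per sample over the batch dim
y = drop_path(x, drop_prob=0.5, training=True)
print(y[:, 0, 0])                            # each entry is either 0.0 or 2.0 (= 1 / keep_prob)

y_eval = drop_path(x, drop_prob=0.5, training=False)
assert torch.equal(y_eval, x)                # identity outside training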
capvector-pi05/src/vggt/layers/layer_scale.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
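LayerScale is a learnable per-channel multiplier over the last dimension, initialised to a small constant so residual branches start close to the identity. A minimal sketch (hypothetical sizes):

import torch

ls = LayerScale(dim=4, init_values=1e-5)
x = torch.ones(2, 3, 4)                      # (..., dim)
y = ls(x)
print(y[0, 0])                               # every channel scaled by the initial gamma of 1e-05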
capvector-pi05/src/vggt/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
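A minimal usage sketch for the Mlp above (hypothetical sizes): with the defaults it is the standard two-layer transformer feed-forward block, and since out_features defaults to in_features it preserves the token shape.

import torch

mlp = Mlp(in_features=384, hidden_features=1536, drop=0.1)
mlp.eval()                                   # disable dropout for a deterministic check
x = torch.randn(2, 196, 384)
with torch.no_grad():
    y = mlp(x)
assert y.shape == x.shape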