haofuly committed on
Commit 4e80de3 · verified · 1 Parent(s): 5e4171f

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. capvector-pi05/src/openpi/policies/policy_test.py +34 -0
  2. capvector-pi05/src/openpi/serving/websocket_policy_server.py +90 -0
  3. capvector-pi05/src/openpi/shared/__init__.py +0 -0
  4. capvector-pi05/src/openpi/shared/download.py +194 -0
  5. capvector-pi05/src/openpi/shared/download_test.py +54 -0
  6. capvector-pi05/src/openpi/shared/image_tools.py +186 -0
  7. capvector-pi05/src/openpi/shared/image_tools_test.py +37 -0
  8. capvector-pi05/src/openpi/shared/nnx_utils.py +69 -0
  9. capvector-pi05/src/openpi/shared/normalize.py +146 -0
  10. capvector-pi05/src/openpi/shared/normalize_test.py +43 -0
  11. capvector-pi05/src/openpi/training/checkpoints.py +159 -0
  12. capvector-pi05/src/openpi/training/config.py +1033 -0
  13. capvector-pi05/src/openpi/training/data_loader.py +540 -0
  14. capvector-pi05/src/openpi/training/data_loader_test.py +84 -0
  15. capvector-pi05/src/openpi/training/droid_rlds_dataset.py +221 -0
  16. capvector-pi05/src/openpi/training/misc/roboarena_config.py +116 -0
  17. capvector-pi05/src/openpi/training/optimizer.py +109 -0
  18. capvector-pi05/src/openpi/training/sharding.py +102 -0
  19. capvector-pi05/src/openpi/training/utils.py +38 -0
  20. capvector-pi05/src/openpi/training/weight_loaders.py +104 -0
  21. capvector-pi05/src/vggt/__init__.py +0 -0
  22. capvector-pi05/src/vggt/dependency/__init__.py +3 -0
  23. capvector-pi05/src/vggt/dependency/distortion.py +182 -0
  24. capvector-pi05/src/vggt/dependency/np_to_pycolmap.py +320 -0
  25. capvector-pi05/src/vggt/dependency/projection.py +228 -0
  26. capvector-pi05/src/vggt/dependency/track_modules/__init__.py +0 -0
  27. capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py +190 -0
  28. capvector-pi05/src/vggt/dependency/track_modules/blocks.py +329 -0
  29. capvector-pi05/src/vggt/dependency/track_modules/modules.py +202 -0
  30. capvector-pi05/src/vggt/dependency/track_modules/track_refine.py +419 -0
  31. capvector-pi05/src/vggt/dependency/track_modules/utils.py +216 -0
  32. capvector-pi05/src/vggt/dependency/track_predict.py +326 -0
  33. capvector-pi05/src/vggt/dependency/vggsfm_tracker.py +124 -0
  34. capvector-pi05/src/vggt/dependency/vggsfm_utils.py +305 -0
  35. capvector-pi05/src/vggt/heads/camera_head.py +149 -0
  36. capvector-pi05/src/vggt/heads/dpt_head.py +484 -0
  37. capvector-pi05/src/vggt/heads/head_act.py +125 -0
  38. capvector-pi05/src/vggt/heads/track_head.py +104 -0
  39. capvector-pi05/src/vggt/heads/track_modules/__init__.py +5 -0
  40. capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py +209 -0
  41. capvector-pi05/src/vggt/heads/track_modules/blocks.py +236 -0
  42. capvector-pi05/src/vggt/heads/track_modules/modules.py +204 -0
  43. capvector-pi05/src/vggt/heads/track_modules/utils.py +223 -0
  44. capvector-pi05/src/vggt/heads/utils.py +176 -0
  45. capvector-pi05/src/vggt/layers/__init__.py +11 -0
  46. capvector-pi05/src/vggt/layers/attention.py +93 -0
  47. capvector-pi05/src/vggt/layers/block.py +247 -0
  48. capvector-pi05/src/vggt/layers/drop_path.py +34 -0
  49. capvector-pi05/src/vggt/layers/layer_scale.py +22 -0
  50. capvector-pi05/src/vggt/layers/mlp.py +40 -0
capvector-pi05/src/openpi/policies/policy_test.py ADDED
@@ -0,0 +1,34 @@
+from openpi_client import action_chunk_broker
+import pytest
+
+from openpi.policies import aloha_policy
+from openpi.policies import policy_config as _policy_config
+from openpi.training import config as _config
+
+
+@pytest.mark.manual
+def test_infer():
+    config = _config.get_config("pi0_aloha_sim")
+    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
+
+    example = aloha_policy.make_aloha_example()
+    result = policy.infer(example)
+
+    assert result["actions"].shape == (config.model.action_horizon, 14)
+
+
+@pytest.mark.manual
+def test_broker():
+    config = _config.get_config("pi0_aloha_sim")
+    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
+
+    broker = action_chunk_broker.ActionChunkBroker(
+        policy,
+        # Only execute the first half of the chunk.
+        action_horizon=config.model.action_horizon // 2,
+    )
+
+    example = aloha_policy.make_aloha_example()
+    for _ in range(config.model.action_horizon):
+        outputs = broker.infer(example)
+        assert outputs["actions"].shape == (14,)
capvector-pi05/src/openpi/serving/websocket_policy_server.py ADDED
@@ -0,0 +1,90 @@
+import asyncio
+import http
+import logging
+import time
+import traceback
+
+from openpi_client import base_policy as _base_policy
+from openpi_client import msgpack_numpy
+import websockets.asyncio.server as _server
+import websockets.frames
+
+logger = logging.getLogger(__name__)
+
+
+class WebsocketPolicyServer:
+    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
+
+    Currently only implements the `load` and `infer` methods.
+    """
+
+    def __init__(
+        self,
+        policy: _base_policy.BasePolicy,
+        host: str = "0.0.0.0",
+        port: int | None = None,
+        metadata: dict | None = None,
+    ) -> None:
+        self._policy = policy
+        self._host = host
+        self._port = port
+        self._metadata = metadata or {}
+        logging.getLogger("websockets.server").setLevel(logging.INFO)
+
+    def serve_forever(self) -> None:
+        asyncio.run(self.run())
+
+    async def run(self):
+        async with _server.serve(
+            self._handler,
+            self._host,
+            self._port,
+            compression=None,
+            max_size=None,
+            process_request=_health_check,
+        ) as server:
+            await server.serve_forever()
+
+    async def _handler(self, websocket: _server.ServerConnection):
+        logger.info(f"Connection from {websocket.remote_address} opened")
+        packer = msgpack_numpy.Packer()
+
+        await websocket.send(packer.pack(self._metadata))
+
+        prev_total_time = None
+        while True:
+            try:
+                start_time = time.monotonic()
+                obs = msgpack_numpy.unpackb(await websocket.recv())
+
+                infer_time = time.monotonic()
+                action = self._policy.infer(obs)
+                infer_time = time.monotonic() - infer_time
+
+                action["server_timing"] = {
+                    "infer_ms": infer_time * 1000,
+                }
+                if prev_total_time is not None:
+                    # We can only record the last total time since we also want to include the send time.
+                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000
+
+                await websocket.send(packer.pack(action))
+                prev_total_time = time.monotonic() - start_time
+
+            except websockets.ConnectionClosed:
+                logger.info(f"Connection from {websocket.remote_address} closed")
+                break
+            except Exception:
+                await websocket.send(traceback.format_exc())
+                await websocket.close(
+                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
+                    reason="Internal server error. Traceback included in previous frame.",
+                )
+                raise
+
+
+def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
+    if request.path == "/healthz":
+        return connection.respond(http.HTTPStatus.OK, "OK\n")
+    # Continue with the normal request handling.
+    return None
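For reference, a minimal client-side sketch of the protocol implemented above (one metadata frame on connect, then alternating observation/action frames). This is not part of the commit: it assumes the `websockets` package also provides the synchronous client (`websockets.sync.client`) and reuses the `msgpack_numpy` helpers shown in the server; in practice the `websocket_client_policy.py` referenced in the docstring is the supported client.

from openpi_client import msgpack_numpy
import websockets.sync.client


def query_policy_once(uri: str, observation: dict) -> dict:
    """Connect, read the metadata frame, send one observation, and return the action chunk."""
    packer = msgpack_numpy.Packer()
    with websockets.sync.client.connect(uri, max_size=None) as ws:
        _ = msgpack_numpy.unpackb(ws.recv())  # first frame is the server metadata
        ws.send(packer.pack(observation))
        return msgpack_numpy.unpackb(ws.recv())  # includes the "server_timing" entry added above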
capvector-pi05/src/openpi/shared/__init__.py ADDED
File without changes
capvector-pi05/src/openpi/shared/download.py ADDED
@@ -0,0 +1,194 @@
+import concurrent.futures
+import datetime
+import logging
+import os
+import pathlib
+import re
+import shutil
+import stat
+import time
+import urllib.parse
+
+import filelock
+import fsspec
+import fsspec.generic
+import tqdm_loggable.auto as tqdm
+
+# Environment variable to control cache directory path, ~/.cache/openpi will be used by default.
+_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"
+DEFAULT_CACHE_DIR = "~/.cache/openpi"
+
+logger = logging.getLogger(__name__)
+
+
+def get_cache_dir() -> pathlib.Path:
+    cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve()
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    _set_folder_permission(cache_dir)
+    return cache_dir
+
+
+def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
+    """Download a file or directory from a remote filesystem to the local cache, and return the local path.
+
+    If the local file already exists, it will be returned directly.
+
+    It is safe to call this function concurrently from multiple processes.
+    See `get_cache_dir` for more details on the cache directory.
+
+    Args:
+        url: URL to the file to download.
+        force_download: If True, the file will be downloaded even if it already exists in the cache.
+        **kwargs: Additional arguments to pass to fsspec.
+
+    Returns:
+        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.
+    """
+    # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem.
+    parsed = urllib.parse.urlparse(url)
+
+    # Short circuit if this is a local path.
+    if parsed.scheme == "":
+        path = pathlib.Path(url)
+        if not path.exists():
+            raise FileNotFoundError(f"File not found at {url}")
+        return path.resolve()
+
+    cache_dir = get_cache_dir()
+
+    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
+    local_path = local_path.resolve()
+
+    # Check if the cache should be invalidated.
+    invalidate_cache = False
+    if local_path.exists():
+        if force_download or _should_invalidate_cache(cache_dir, local_path):
+            invalidate_cache = True
+        else:
+            return local_path
+
+    try:
+        lock_path = local_path.with_suffix(".lock")
+        with filelock.FileLock(lock_path):
+            # Ensure consistent permissions for the lock file.
+            _ensure_permissions(lock_path)
+            # First, remove the existing cache if it is expired.
+            if invalidate_cache:
+                logger.info(f"Removing expired cached entry: {local_path}")
+                if local_path.is_dir():
+                    shutil.rmtree(local_path)
+                else:
+                    local_path.unlink()
+
+            # Download the data to a local cache.
+            logger.info(f"Downloading {url} to {local_path}")
+            scratch_path = local_path.with_suffix(".partial")
+            _download_fsspec(url, scratch_path, **kwargs)
+
+            shutil.move(scratch_path, local_path)
+            _ensure_permissions(local_path)
+
+    except PermissionError as e:
+        msg = (
+            f"Local file permission error was encountered while downloading {url}. "
+            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
+        )
+        raise PermissionError(msg) from e
+
+    return local_path
+
+
+def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
+    """Download a file from a remote filesystem to the local cache, and return the local path."""
+    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
+    info = fs.info(url)
+    # Folders are represented by 0-byte objects with a trailing forward slash.
+    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
+        total_size = fs.du(url)
+    else:
+        total_size = info["size"]
+    with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
+        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        future = executor.submit(fs.get, url, local_path, recursive=is_dir)
+        while not future.done():
+            current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file())
+            pbar.update(current_size - pbar.n)
+            time.sleep(1)
+        pbar.update(total_size - pbar.n)
+
+
+def _set_permission(path: pathlib.Path, target_permission: int):
+    """chmod requires executable permission to be set, so we skip if the permission already matches the target."""
+    if path.stat().st_mode & target_permission == target_permission:
+        logger.debug(f"Skipping {path} because it already has correct permissions")
+        return
+    path.chmod(target_permission)
+    logger.debug(f"Set {path} to {target_permission}")
+
+
+def _set_folder_permission(folder_path: pathlib.Path) -> None:
+    """Set folder permission to be read, write and searchable."""
+    _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
+
+
+def _ensure_permissions(path: pathlib.Path) -> None:
+    """Since we are sharing the cache directory with the containerized runtime as well as the training script, we need
+    to ensure that the cache directory has the correct permissions.
+    """
+
+    def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None:
+        cache_dir = get_cache_dir()
+        relative_path = path.relative_to(cache_dir)
+        moving_path = cache_dir
+        for part in relative_path.parts:
+            _set_folder_permission(moving_path / part)
+            moving_path = moving_path / part
+
+    def _set_file_permission(file_path: pathlib.Path) -> None:
+        """Set all files to be read & writable; if it is a script, keep it as a script."""
+        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
+        if file_path.stat().st_mode & 0o100:
+            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
+        else:
+            _set_permission(file_path, file_rw)
+
+    _setup_folder_permission_between_cache_dir_and_path(path)
+    for root, dirs, files in os.walk(str(path)):
+        root_path = pathlib.Path(root)
+        for file in files:
+            file_path = root_path / file
+            _set_file_permission(file_path)
+
+        for dir in dirs:
+            dir_path = root_path / dir
+            _set_folder_permission(dir_path)
+
+
+def _get_mtime(year: int, month: int, day: int) -> float:
+    """Get the mtime of a given date at midnight UTC."""
+    date = datetime.datetime(year, month, day, tzinfo=datetime.UTC)
+    return time.mktime(date.timetuple())
+
+
+# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format).
+# Partial matching will be used from top to bottom and the first match will be chosen.
+# Cached entries will be retained only if they are newer than the expiration timestamp.
+_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
+    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
+    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
+    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
+}
+
+
+def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
+    """Invalidate the cache if it is expired. Return True if the cache was invalidated."""
+
+    assert local_path.exists(), f"File not found at {local_path}"
+
+    relative_path = str(local_path.relative_to(cache_dir))
+    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
+        if pattern.match(relative_path):
+            # Remove if not newer than the expiration timestamp.
+            return local_path.stat().st_mtime <= expire_time
+
+    return False
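A short usage sketch for `maybe_download` (the cache-directory environment variable comes from the code above; the checkpoint URL is the one used elsewhere in this commit and is only illustrative here):

import os

import openpi.shared.download as download

# Optional: move the cache away from the default ~/.cache/openpi.
os.environ["OPENPI_DATA_HOME"] = "/tmp/openpi_cache"

# Returns an absolute local path; repeated calls are served from the cache
# unless force_download=True or the entry matches _INVALIDATE_CACHE_DIRS.
local_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_aloha_sim")
print(local_dir)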
capvector-pi05/src/openpi/shared/download_test.py ADDED
@@ -0,0 +1,54 @@
+import pathlib
+
+import pytest
+
+import openpi.shared.download as download
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_openpi_data_home(tmp_path_factory):
+    temp_dir = tmp_path_factory.mktemp("openpi_data")
+    with pytest.MonkeyPatch().context() as mp:
+        mp.setenv("OPENPI_DATA_HOME", str(temp_dir))
+        yield
+
+
+def test_download_local(tmp_path: pathlib.Path):
+    local_path = tmp_path / "local"
+    local_path.touch()
+
+    result = download.maybe_download(str(local_path))
+    assert result == local_path
+
+    with pytest.raises(FileNotFoundError):
+        download.maybe_download("bogus")
+
+
+def test_download_gs_dir():
+    remote_path = "gs://openpi-assets/testdata/random"
+
+    local_path = download.maybe_download(remote_path)
+    assert local_path.exists()
+
+    new_local_path = download.maybe_download(remote_path)
+    assert new_local_path == local_path
+
+
+def test_download_gs():
+    remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin"
+
+    local_path = download.maybe_download(remote_path)
+    assert local_path.exists()
+
+    new_local_path = download.maybe_download(remote_path)
+    assert new_local_path == local_path
+
+
+def test_download_fsspec():
+    remote_path = "gs://big_vision/paligemma_tokenizer.model"
+
+    local_path = download.maybe_download(remote_path, gs={"token": "anon"})
+    assert local_path.exists()
+
+    new_local_path = download.maybe_download(remote_path, gs={"token": "anon"})
+    assert new_local_path == local_path
capvector-pi05/src/openpi/shared/image_tools.py ADDED
@@ -0,0 +1,186 @@
+import functools
+
+import jax
+import jax.numpy as jnp
+import torch
+import torch.nn.functional as F  # noqa: N812
+
+import openpi.shared.array_typing as at
+
+
+@functools.partial(jax.jit, static_argnums=(1, 2, 3))
+@at.typecheck
+def resize_with_pad(
+    images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"],
+    height: int,
+    width: int,
+    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,
+) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]:
+    """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion
+    by padding with black. If the image is float32, it must be in the range [-1, 1].
+    """
+    has_batch_dim = images.ndim == 4
+    if not has_batch_dim:
+        images = images[None]  # type: ignore
+    cur_height, cur_width = images.shape[1:3]
+    ratio = max(cur_width / width, cur_height / height)
+    resized_height = int(cur_height / ratio)
+    resized_width = int(cur_width / ratio)
+    resized_images = jax.image.resize(
+        images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method
+    )
+    if images.dtype == jnp.uint8:
+        # round from float back to uint8
+        resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8)
+    elif images.dtype == jnp.float32:
+        resized_images = resized_images.clip(-1.0, 1.0)
+    else:
+        raise ValueError(f"Unsupported image dtype: {images.dtype}")
+
+    pad_h0, remainder_h = divmod(height - resized_height, 2)
+    pad_h1 = pad_h0 + remainder_h
+    pad_w0, remainder_w = divmod(width - resized_width, 2)
+    pad_w1 = pad_w0 + remainder_w
+    padded_images = jnp.pad(
+        resized_images,
+        ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)),
+        constant_values=0 if images.dtype == jnp.uint8 else -1.0,
+    )
+
+    if not has_batch_dim:
+        padded_images = padded_images[0]
+    return padded_images
+
+
+def resize_with_pad_torch(
+    images: torch.Tensor,
+    height: int,
+    width: int,
+    mode: str = "bilinear",
+) -> torch.Tensor:
+    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
+    by padding with black. If the image is float32, it must be in the range [-1, 1].
+
+    Args:
+        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
+        height: Target height
+        width: Target width
+        mode: Interpolation mode ('bilinear', 'nearest', etc.)
+
+    Returns:
+        Resized and padded tensor with same shape format as input
+    """
+    # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
+    if images.shape[-1] <= 4:  # Assume channels-last format
+        channels_last = True
+        # Convert to channels-first for torch operations
+        if images.dim() == 3:
+            images = images.unsqueeze(0)  # Add batch dimension
+        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]
+    else:
+        channels_last = False
+        if images.dim() == 3:
+            images = images.unsqueeze(0)  # Add batch dimension
+
+    batch_size, channels, cur_height, cur_width = images.shape
+
+    # Calculate resize ratio
+    ratio = max(cur_width / width, cur_height / height)
+    resized_height = int(cur_height / ratio)
+    resized_width = int(cur_width / ratio)
+
+    # Resize
+    resized_images = F.interpolate(
+        images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
+    )
+
+    # Handle dtype-specific clipping
+    if images.dtype == torch.uint8:
+        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
+    elif images.dtype == torch.float32:
+        resized_images = resized_images.clamp(-1.0, 1.0)
+    else:
+        raise ValueError(f"Unsupported image dtype: {images.dtype}")
+
+    # Calculate padding
+    pad_h0, remainder_h = divmod(height - resized_height, 2)
+    pad_h1 = pad_h0 + remainder_h
+    pad_w0, remainder_w = divmod(width - resized_width, 2)
+    pad_w1 = pad_w0 + remainder_w
+
+    # Pad
+    constant_value = 0 if images.dtype == torch.uint8 else -1.0
+    padded_images = F.pad(
+        resized_images,
+        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
+        mode="constant",
+        value=constant_value,
+    )
+
+    # Convert back to original format if needed
+    if channels_last:
+        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
+    if batch_size == 1 and images.shape[0] == 1:
+        padded_images = padded_images.squeeze(0)  # Remove batch dimension if it was added
+
+    return padded_images
+
+
+def replace_padding_0to1_torch(image: torch.Tensor) -> torch.Tensor:
+    """PyTorch version of replace_padding_0to1.
+    OpenPI expects images with 0-valued padding, while the VGGT models expect 1-valued padding.
+    This converts between the two using a bounding-box based padding replacement.
+    Args:
+        image: Tensor of shape [*b, h, w, c]
+    Returns:
+        Padding-replaced tensor with same shape as input
+    """
+    single = False
+    if image.dim() == 3:
+        image = image.unsqueeze(0)
+        single = True
+
+    b, h, w, c = image.shape
+    device = image.device
+
+    nonzero_any = (image != 0).any(dim=-1)
+
+    row_any = nonzero_any.any(dim=2)
+    col_any = nonzero_any.any(dim=1)
+
+    top = row_any.to(torch.float32).argmax(dim=1)
+    bottom = h - 1 - row_any.flip(dims=[1]).to(torch.float32).argmax(dim=1)
+    left = col_any.to(torch.float32).argmax(dim=1)
+    right = w - 1 - col_any.flip(dims=[1]).to(torch.float32).argmax(dim=1)
+
+    has_any = row_any.any(dim=1)
+    top = torch.where(has_any, top, torch.zeros_like(top))
+    bottom = torch.where(has_any, bottom, torch.full_like(bottom, h - 1))
+    left = torch.where(has_any, left, torch.zeros_like(left))
+    right = torch.where(has_any, right, torch.full_like(right, w - 1))
+
+    rows = torch.arange(h, device=device).view(1, h, 1)
+    cols = torch.arange(w, device=device).view(1, 1, w)
+    top_v = top.view(b, 1, 1)
+    bottom_v = bottom.view(b, 1, 1)
+    left_v = left.view(b, 1, 1)
+    right_v = right.view(b, 1, 1)
+
+    row_mask = (rows >= top_v) & (rows <= bottom_v)
+    col_mask = (cols >= left_v) & (cols <= right_v)
+    inside_mask = row_mask & col_mask
+
+    padding_mask = ~inside_mask
+
+    pixel_zero = (image == 0).all(dim=-1)
+
+    final_mask = padding_mask & pixel_zero
+
+    if final_mask.any():
+        mask_exp = final_mask.unsqueeze(-1).expand_as(image)
+        one_t = torch.tensor(1, dtype=image.dtype, device=device)
+        image = torch.where(mask_exp, one_t, image)
+
+    if single:
+        image = image.squeeze(0)
+    return image
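For illustration, a small sketch of the padding-replacement helper above; the input tensor is fabricated and stands in for the output of the resize-with-pad step (uint8 images are padded with 0 there):

import torch

from openpi.shared import image_tools

# A 224x224 uint8 frame whose content occupies a central box; the border is 0-valued padding.
image = torch.zeros(224, 224, 3, dtype=torch.uint8)
image[32:192, 16:208] = torch.randint(1, 256, (160, 192, 3), dtype=torch.uint8)

# VGGT-style consumers expect 1-valued padding, so zeros outside the content box become 1.
vggt_ready = image_tools.replace_padding_0to1_torch(image)
assert vggt_ready[0, 0, 0] == 1
assert torch.equal(vggt_ready[32:192, 16:208], image[32:192, 16:208])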
capvector-pi05/src/openpi/shared/image_tools_test.py ADDED
@@ -0,0 +1,37 @@
+import jax.numpy as jnp
+
+from openpi.shared import image_tools
+
+
+def test_resize_with_pad_shapes():
+    # Test case 1: Resize image with larger dimensions
+    images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8)  # Input images of shape (batch_size, height, width, channels)
+    height = 20
+    width = 20
+    resized_images = image_tools.resize_with_pad(images, height, width)
+    assert resized_images.shape == (2, height, width, 3)
+    assert jnp.all(resized_images == 0)
+
+    # Test case 2: Resize image with smaller dimensions
+    images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8)
+    height = 15
+    width = 15
+    resized_images = image_tools.resize_with_pad(images, height, width)
+    assert resized_images.shape == (3, height, width, 3)
+    assert jnp.all(resized_images == 0)
+
+    # Test case 3: Resize image with the same dimensions
+    images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8)
+    height = 50
+    width = 50
+    resized_images = image_tools.resize_with_pad(images, height, width)
+    assert resized_images.shape == (1, height, width, 3)
+    assert jnp.all(resized_images == 0)
+
+    # Test case 4: Resize image with odd-numbered padding
+    images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8)
+    height = 60
+    width = 80
+    resized_images = image_tools.resize_with_pad(images, height, width)
+    assert resized_images.shape == (1, height, width, 3)
+    assert jnp.all(resized_images == 0)
capvector-pi05/src/openpi/shared/nnx_utils.py ADDED
@@ -0,0 +1,69 @@
+from collections.abc import Callable
+import dataclasses
+import functools
+import inspect
+import re
+from typing import Any, ParamSpec, TypeVar
+
+import flax.nnx as nnx
+import jax
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]:
+    """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process.
+
+    Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much
+    more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module
+    mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must
+    traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details.
+
+    `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by
+    `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was
+    when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded
+    after the method call completes.
+    """
+    if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)):
+        raise ValueError("module_jit must only be used on bound methods of nnx.Modules.")
+
+    graphdef, state = nnx.split(meth.__self__)
+
+    def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R:
+        module = nnx.merge(graphdef, state)
+        return meth.__func__(module, *args, **kwargs)
+
+    jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs)
+
+    @functools.wraps(meth)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        return jitted_fn(state, *args, **kwargs)
+
+    return wrapper
+
+
+@dataclasses.dataclass(frozen=True)
+class PathRegex:
+    """NNX Filter that matches paths using a regex.
+
+    By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument.
+    """
+
+    pattern: str | re.Pattern
+    sep: str = "/"
+
+    def __post_init__(self):
+        if not isinstance(self.pattern, re.Pattern):
+            object.__setattr__(self, "pattern", re.compile(self.pattern))
+
+    def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool:
+        joined_path = self.sep.join(str(x) for x in path)
+        assert isinstance(self.pattern, re.Pattern)
+        return self.pattern.fullmatch(joined_path) is not None
+
+
+def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State:
+    """Apply a function to the leaves of the state that match the filter."""
+    filtered_keys = set(state.filter(filter).flat_state())
+    return state.map(lambda k, v: fn(v) if k in filtered_keys else v)
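A minimal sketch of how `module_jit` and `PathRegex` are meant to be used; the `Tiny` module is made up and assumes the standard `flax.nnx` API (`nnx.Linear`, `nnx.Rngs`):

import flax.nnx as nnx
import jax.numpy as jnp

from openpi.shared import nnx_utils


class Tiny(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 2, rngs=rngs)

    def forward(self, x):
        return self.linear(x)


model = Tiny(nnx.Rngs(0))
fast_forward = nnx_utils.module_jit(model.forward)  # module state is frozen at this point
out = fast_forward(jnp.ones((1, 4)))

# PathRegex is an NNX filter; e.g. select parameters whose path contains "lora".
lora_filter = nnx_utils.PathRegex(".*lora.*")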
capvector-pi05/src/openpi/shared/normalize.py ADDED
@@ -0,0 +1,146 @@
+import json
+import pathlib
+
+import numpy as np
+import numpydantic
+import pydantic
+
+
+@pydantic.dataclasses.dataclass
+class NormStats:
+    mean: numpydantic.NDArray
+    std: numpydantic.NDArray
+    q01: numpydantic.NDArray | None = None  # 1st quantile
+    q99: numpydantic.NDArray | None = None  # 99th quantile
+
+
+class RunningStats:
+    """Compute running statistics of a batch of vectors."""
+
+    def __init__(self):
+        self._count = 0
+        self._mean = None
+        self._mean_of_squares = None
+        self._min = None
+        self._max = None
+        self._histograms = None
+        self._bin_edges = None
+        self._num_quantile_bins = 5000  # for computing quantiles on the fly
+
+    def update(self, batch: np.ndarray) -> None:
+        """
+        Update the running statistics with a batch of vectors.
+
+        Args:
+            batch (np.ndarray): An array where all dimensions except the last are batch dimensions.
+        """
+        batch = batch.reshape(-1, batch.shape[-1])
+        num_elements, vector_length = batch.shape
+        if self._count == 0:
+            self._mean = np.mean(batch, axis=0)
+            self._mean_of_squares = np.mean(batch**2, axis=0)
+            self._min = np.min(batch, axis=0)
+            self._max = np.max(batch, axis=0)
+            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
+            self._bin_edges = [
+                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
+                for i in range(vector_length)
+            ]
+        else:
+            if vector_length != self._mean.size:
+                raise ValueError("The length of new vectors does not match the initialized vector length.")
+            new_max = np.max(batch, axis=0)
+            new_min = np.min(batch, axis=0)
+            max_changed = np.any(new_max > self._max)
+            min_changed = np.any(new_min < self._min)
+            self._max = np.maximum(self._max, new_max)
+            self._min = np.minimum(self._min, new_min)
+
+            if max_changed or min_changed:
+                self._adjust_histograms()
+
+        self._count += num_elements
+
+        batch_mean = np.mean(batch, axis=0)
+        batch_mean_of_squares = np.mean(batch**2, axis=0)
+
+        # Update running mean and mean of squares.
+        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
+        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count)
+
+        self._update_histograms(batch)
+
+    def get_statistics(self) -> NormStats:
+        """
+        Compute and return the statistics of the vectors processed so far.
+
+        Returns:
+            NormStats: The computed statistics.
+        """
+        if self._count < 2:
+            raise ValueError("Cannot compute statistics for less than 2 vectors.")
+
+        variance = self._mean_of_squares - self._mean**2
+        stddev = np.sqrt(np.maximum(0, variance))
+        q01, q99 = self._compute_quantiles([0.01, 0.99])
+        return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99)
+
+    def _adjust_histograms(self):
+        """Adjust histograms when min or max changes."""
+        for i in range(len(self._histograms)):
+            old_edges = self._bin_edges[i]
+            new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1)
+
+            # Redistribute the existing histogram counts to the new bins
+            new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i])
+
+            self._histograms[i] = new_hist
+            self._bin_edges[i] = new_edges
+
+    def _update_histograms(self, batch: np.ndarray) -> None:
+        """Update histograms with new vectors."""
+        for i in range(batch.shape[1]):
+            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
+            self._histograms[i] += hist
+
+    def _compute_quantiles(self, quantiles):
+        """Compute quantiles based on histograms."""
+        results = []
+        for q in quantiles:
+            target_count = q * self._count
+            q_values = []
+            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
+                cumsum = np.cumsum(hist)
+                idx = np.searchsorted(cumsum, target_count)
+                q_values.append(edges[idx])
+            results.append(np.array(q_values))
+        return results
+
+
+class _NormStatsDict(pydantic.BaseModel):
+    norm_stats: dict[str, NormStats]
+
+
+def serialize_json(norm_stats: dict[str, NormStats]) -> str:
+    """Serialize the running statistics to a JSON string."""
+    return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2)
+
+
+def deserialize_json(data: str) -> dict[str, NormStats]:
+    """Deserialize the running statistics from a JSON string."""
+    return _NormStatsDict(**json.loads(data)).norm_stats
+
+
+def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None:
+    """Save the normalization stats to a directory."""
+    path = pathlib.Path(directory) / "norm_stats.json"
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(serialize_json(norm_stats))
+
+
+def load(directory: pathlib.Path | str) -> dict[str, NormStats]:
+    """Load the normalization stats from a directory."""
+    path = pathlib.Path(directory) / "norm_stats.json"
+    if not path.exists():
+        raise FileNotFoundError(f"Norm stats file not found at: {path}")
+    return deserialize_json(path.read_text())
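A small end-to-end sketch of the API above (the asset key "state" and the output directory are only illustrative):

import numpy as np

import openpi.shared.normalize as normalize

stats = normalize.RunningStats()
for _ in range(10):
    stats.update(np.random.rand(32, 7))  # batches of 7-dimensional vectors

norm_stats = {"state": stats.get_statistics()}
normalize.save("/tmp/assets/my_dataset", norm_stats)  # writes norm_stats.json
reloaded = normalize.load("/tmp/assets/my_dataset")
assert np.allclose(reloaded["state"].mean, norm_stats["state"].mean)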
capvector-pi05/src/openpi/shared/normalize_test.py ADDED
@@ -0,0 +1,43 @@
+import numpy as np
+
+import openpi.shared.normalize as normalize
+
+
+def test_normalize_update():
+    arr = np.arange(12).reshape(4, 3)  # 4 vectors of length 3
+
+    stats = normalize.RunningStats()
+    for i in range(len(arr)):
+        stats.update(arr[i : i + 1])  # Update with one vector at a time
+    results = stats.get_statistics()
+
+    assert np.allclose(results.mean, np.mean(arr, axis=0))
+    assert np.allclose(results.std, np.std(arr, axis=0))
+
+
+def test_serialize_deserialize():
+    stats = normalize.RunningStats()
+    stats.update(np.arange(12).reshape(4, 3))  # 4 vectors of length 3
+
+    norm_stats = {"test": stats.get_statistics()}
+    norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats))
+    assert np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean)
+    assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std)
+
+
+def test_multiple_batch_dimensions():
+    # Test with multiple batch dimensions: (2, 3, 4) where 4 is the vector dimension
+    batch_shape = (2, 3, 4)
+    arr = np.random.rand(*batch_shape)
+
+    stats = normalize.RunningStats()
+    stats.update(arr)  # Should handle (2, 3, 4) -> reshape to (6, 4)
+    results = stats.get_statistics()
+
+    # Flatten batch dimensions and compute expected stats
+    flattened = arr.reshape(-1, arr.shape[-1])  # (6, 4)
+    expected_mean = np.mean(flattened, axis=0)
+    expected_std = np.std(flattened, axis=0)
+
+    assert np.allclose(results.mean, expected_mean)
+    assert np.allclose(results.std, expected_std)
capvector-pi05/src/openpi/training/checkpoints.py ADDED
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import asyncio
+import concurrent.futures as futures
+import dataclasses
+import logging
+from typing import Protocol
+
+from etils import epath
+import jax
+import orbax.checkpoint as ocp
+import orbax.checkpoint.future as future
+
+from openpi.shared import array_typing as at
+import openpi.shared.normalize as _normalize
+import openpi.training.data_loader as _data_loader
+import openpi.training.utils as training_utils
+
+
+def initialize_checkpoint_dir(
+    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
+) -> tuple[ocp.CheckpointManager, bool]:
+    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
+    resuming = False
+    if checkpoint_dir.exists():
+        if overwrite:
+            checkpoint_dir.rmtree()
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
+        elif resume:
+            resuming = True
+        else:
+            raise FileExistsError(
+                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
+                "to indicate how to handle it."
+            )
+
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    mngr = ocp.CheckpointManager(
+        checkpoint_dir,
+        item_handlers={
+            "assets": CallbackHandler(),
+            "train_state": ocp.PyTreeCheckpointHandler(),
+            "params": ocp.PyTreeCheckpointHandler(),
+        },
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1,
+            keep_period=keep_period,
+            create=False,
+            async_options=ocp.AsyncOptions(timeout_secs=7200),
+        ),
+    )
+
+    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
+    # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
+    # checkpoint, since it will fail.
+    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
+        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
+        resuming = False
+
+    return mngr, resuming
+
+
+def save_state(
+    checkpoint_manager: ocp.CheckpointManager,
+    state: training_utils.TrainState,
+    data_loader: _data_loader.DataLoader,
+    step: int,
+):
+    def save_assets(directory: epath.Path):
+        # Save the normalization stats.
+        data_config = data_loader.data_config()
+        norm_stats = data_config.norm_stats
+        if norm_stats is not None and data_config.asset_id is not None:
+            _normalize.save(directory / data_config.asset_id, norm_stats)
+
+    # Split params that can be used for inference into a separate item.
+    with at.disable_typechecking():
+        train_state, params = _split_params(state)
+    items = {
+        "assets": save_assets,
+        "train_state": train_state,
+        "params": {"params": params},
+    }
+    checkpoint_manager.save(step, items)
+
+
+def restore_state(
+    checkpoint_manager: ocp.CheckpointManager,
+    state: training_utils.TrainState,
+    data_loader: _data_loader.DataLoader,
+    step: int | None = None,
+) -> training_utils.TrainState:
+    del data_loader
+
+    with at.disable_typechecking():
+        # Split params that can be used for inference into a separate item.
+        train_state, params = _split_params(state)
+        restored = checkpoint_manager.restore(
+            step,
+            items={
+                "train_state": train_state,
+                "params": {"params": params},
+            },
+        )
+    return _merge_params(restored["train_state"], restored["params"])
+
+
+def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
+    norm_stats_dir = epath.Path(assets_dir) / asset_id
+    norm_stats = _normalize.load(norm_stats_dir)
+    logging.info(f"Loaded norm stats from {norm_stats_dir}")
+    return norm_stats
+
+
+class Callback(Protocol):
+    def __call__(self, directory: epath.Path) -> None: ...
+
+
+class CallbackHandler(ocp.AsyncCheckpointHandler):
+    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""
+
+    def save(self, directory: epath.Path, args: CallbackSave):
+        if jax.process_index() == 0:
+            args.callback(directory)
+
+    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
+        return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]
+
+    def restore(self, *args, **kwargs):
+        raise NotImplementedError("CallbackHandler does not support restore")
+
+
+@ocp.args.register_with_handler(CallbackHandler, for_save=True)
+@dataclasses.dataclass
+class CallbackSave(ocp.args.CheckpointArgs):
+    callback: Callback
+
+
+@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
+class CallbackRestore(ocp.args.CheckpointArgs): ...
+
+
+def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
+    if state.ema_params is not None:
+        params = state.ema_params
+        train_state = dataclasses.replace(state, ema_params=None)
+    else:
+        params = state.params
+        train_state = dataclasses.replace(state, params={})
+    return train_state, params
+
+
+def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
+    # Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split.
+    if train_state.params:
+        return dataclasses.replace(train_state, ema_params=params["params"])
+    return dataclasses.replace(train_state, params=params["params"])
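A hedged sketch of the checkpoint-directory bootstrap above; the experiment path is made up, and the exact on-disk location of the saved "assets" item (read back via `load_norm_stats`) depends on the Orbax step directory, so treat the second call as schematic:

from openpi.training import checkpoints as _checkpoints

# Create the checkpoint directory, or resume from it if it already holds checkpoints.
mngr, resuming = _checkpoints.initialize_checkpoint_dir(
    "./checkpoints/my_experiment", keep_period=5_000, overwrite=False, resume=True
)

# Norm stats written by save_state land under the "assets" item, keyed by asset id.
norm_stats = _checkpoints.load_norm_stats("./checkpoints/my_experiment/10000/assets", "trossen")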
capvector-pi05/src/openpi/training/config.py ADDED
@@ -0,0 +1,1033 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """See _CONFIGS for the list of available configs."""
2
+
3
+ import abc
4
+ from collections.abc import Sequence
5
+ import dataclasses
6
+ import difflib
7
+ import logging
8
+ import pathlib
9
+ from typing import Any, Literal, Protocol, TypeAlias
10
+
11
+ import etils.epath as epath
12
+ import flax.nnx as nnx
13
+ from typing_extensions import override
14
+ import tyro
15
+
16
+ import openpi.models.model as _model
17
+ import openpi.models.pi0_config as pi0_config
18
+ import openpi.models.pi0_fast as pi0_fast
19
+ import openpi.models.tokenizer as _tokenizer
20
+ import openpi.policies.aloha_policy as aloha_policy
21
+ import openpi.policies.droid_policy as droid_policy
22
+ import openpi.policies.libero_policy as libero_policy
23
+ import openpi.shared.download as _download
24
+ import openpi.shared.normalize as _normalize
25
+ import openpi.training.droid_rlds_dataset as droid_rlds_dataset
26
+ import openpi.training.misc.roboarena_config as roboarena_config
27
+ import openpi.training.optimizer as _optimizer
28
+ import openpi.training.weight_loaders as weight_loaders
29
+ import openpi.transforms as _transforms
30
+
31
+ ModelType: TypeAlias = _model.ModelType
32
+ # Work around a tyro issue with using nnx.filterlib.Filter directly.
33
+ Filter: TypeAlias = nnx.filterlib.Filter
34
+
35
+
36
+ @dataclasses.dataclass(frozen=True)
37
+ class AssetsConfig:
38
+ """Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline.
39
+
40
+ These assets will be replicated inside the checkpoint under the `assets/asset_id` directory.
41
+
42
+ This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other
43
+ centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint
44
+ during fine-tuning, use:
45
+
46
+ ```
47
+ AssetsConfig(
48
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
49
+ asset_id="trossen",
50
+ )
51
+ ```
52
+ """
53
+
54
+ # Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from
55
+ # a different checkpoint (e.g., base model checkpoint) or some other centralized location.
56
+ assets_dir: str | None = None
57
+
58
+ # Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe
59
+ # different robot platforms.
60
+ asset_id: str | None = None
61
+
62
+
63
+ @dataclasses.dataclass(frozen=True)
64
+ class DataConfig:
65
+ # LeRobot repo id. If None, fake data will be created.
66
+ repo_id: str | None = None
67
+ # Directory within the assets directory containing the data assets.
68
+ asset_id: str | None = None
69
+ # Contains precomputed normalization stats. If None, normalization will not be performed.
70
+ norm_stats: dict[str, _transforms.NormStats] | None = None
71
+
72
+ # Used to adopt the inputs from a dataset specific format to a common format
73
+ # which is expected by the data transforms.
74
+ repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
75
+ # Data transforms, typically include robot specific transformations. Will be applied
76
+ # before the data is normalized. See `model.Observation` and `model.Actions` to learn about the
77
+ # normalized data.
78
+ data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
79
+ # Model specific transforms. Will be applied after the data is normalized.
80
+ model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
81
+ # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
82
+ use_quantile_norm: bool = False
83
+
84
+ # Names of keys that will be used by the data loader to generate the action sequence. The length of the
85
+ # sequence is defined by the `action_horizon` field in the model config. This should be adjusted if your
86
+ # LeRobot dataset is using different keys to represent the action.
87
+ action_sequence_keys: Sequence[str] = ("actions",)
88
+
89
+ # If true, will use the LeRobot dataset task to define the prompt.
90
+ prompt_from_task: bool = False
91
+
92
+ # Only used for RLDS data loader (ie currently only used for DROID).
93
+ rlds_data_dir: str | None = None
94
+ # Action space for DROID dataset.
95
+ action_space: droid_rlds_dataset.DroidActionSpace | None = None
96
+ # Path to the data filter file for DROID dataset
97
+ filter_dict_path: str | None = None
98
+
99
+
100
+ class GroupFactory(Protocol):
101
+ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
102
+ """Create a group."""
103
+
104
+
105
+ @dataclasses.dataclass(frozen=True)
106
+ class ModelTransformFactory(GroupFactory):
107
+ """Creates model transforms for standard pi0 models."""
108
+
109
+ # If provided, will determine the default prompt that be used by the model.
110
+ default_prompt: str | None = None
111
+
112
+ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
113
+ match model_config.model_type:
114
+ case _model.ModelType.PI0:
115
+ return _transforms.Group(
116
+ inputs=[
117
+ _transforms.InjectDefaultPrompt(self.default_prompt),
118
+ _transforms.ResizeImages(224, 224),
119
+ _transforms.TokenizePrompt(
120
+ _tokenizer.PaligemmaTokenizer(model_config.max_token_len),
121
+ ),
122
+ _transforms.PadStatesAndActions(model_config.action_dim),
123
+ ],
124
+ )
125
+ case _model.ModelType.PI05:
126
+ assert isinstance(model_config, pi0_config.Pi0Config)
127
+ return _transforms.Group(
128
+ inputs=[
129
+ _transforms.InjectDefaultPrompt(self.default_prompt),
130
+ _transforms.ResizeImages(224, 224),
131
+ _transforms.TokenizePrompt(
132
+ _tokenizer.PaligemmaTokenizer(model_config.max_token_len),
133
+ discrete_state_input=model_config.discrete_state_input,
134
+ ),
135
+ _transforms.PadStatesAndActions(model_config.action_dim),
136
+ ],
137
+ )
138
+ case _model.ModelType.PI0_FAST:
139
+ tokenizer_cls = (
140
+ _tokenizer.FASTTokenizer
141
+ if model_config.fast_model_tokenizer is None
142
+ else model_config.fast_model_tokenizer
143
+ )
144
+ tokenizer_kwargs = (
145
+ {} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs
146
+ )
147
+ return _transforms.Group(
148
+ inputs=[
149
+ _transforms.InjectDefaultPrompt(self.default_prompt),
150
+ _transforms.ResizeImages(224, 224),
151
+ _transforms.TokenizeFASTInputs(
152
+ tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
153
+ ),
154
+ ],
155
+ outputs=[
156
+ _transforms.ExtractFASTActions(
157
+ tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
158
+ action_horizon=model_config.action_horizon,
159
+ action_dim=model_config.action_dim,
160
+ )
161
+ ],
162
+ )
163
+
164
+
165
+ @dataclasses.dataclass(frozen=True)
166
+ class DataConfigFactory(abc.ABC):
167
+ # The LeRobot repo id.
168
+ repo_id: str = tyro.MISSING
169
+ # Determines how the assets will be loaded.
170
+ assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig)
171
+ # Base config that will be updated by the factory.
172
+ base_config: tyro.conf.Suppress[DataConfig | None] = None
173
+
174
+ @abc.abstractmethod
175
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
176
+ """Create a data config."""
177
+
178
+ def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
179
+ repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None
180
+ asset_id = self.assets.asset_id or repo_id
181
+ return dataclasses.replace(
182
+ self.base_config or DataConfig(),
183
+ repo_id=repo_id,
184
+ asset_id=asset_id,
185
+ norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
186
+ use_quantile_norm=model_config.model_type != ModelType.PI0,
187
+ )
188
+
189
+ def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
190
+ if asset_id is None:
191
+ return None
192
+ try:
193
+ data_assets_dir = str(assets_dir / asset_id)
194
+ norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))
195
+ logging.info(f"Loaded norm stats from {data_assets_dir}")
196
+ return norm_stats
197
+ except FileNotFoundError:
198
+ logging.info(f"Norm stats not found in {data_assets_dir}, skipping.")
199
+ return None
200
+
201
+
202
+ @dataclasses.dataclass(frozen=True)
203
+ class FakeDataConfig(DataConfigFactory):
204
+ repo_id: str = "fake"
205
+
206
+ @override
207
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
208
+ return DataConfig(repo_id=self.repo_id)
209
+
210
+
211
+ @dataclasses.dataclass(frozen=True)
212
+ class SimpleDataConfig(DataConfigFactory):
213
+ # Factory for the data transforms.
214
+ data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory)
215
+ # Factory for the model transforms.
216
+ model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory)
217
+
218
+ @override
219
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
220
+ return dataclasses.replace(
221
+ self.create_base_config(assets_dirs, model_config),
222
+ data_transforms=self.data_transforms(model_config),
223
+ model_transforms=self.model_transforms(model_config),
224
+ )
225
+
226
+
227
+ @dataclasses.dataclass(frozen=True)
228
+ class LeRobotAlohaDataConfig(DataConfigFactory):
229
+ # If true, will convert joint dimensions to deltas with respect to the current state before passing to the model.
230
+ # Gripper dimensions will remain in absolute values.
231
+ use_delta_joint_actions: bool = True
232
+ # If provided, will be injected into the input data if the "prompt" key is not present.
233
+ default_prompt: str | None = None
234
+ # If true, this will convert the joint and gripper values from the standard Aloha space to
235
+ # the space used by the pi internal runtime which was used to train the base model. People who
236
+ # use standard Aloha data should set this to true.
237
+ adapt_to_pi: bool = True
238
+
239
+ # Repack transforms.
240
+ repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field(
241
+ default=_transforms.Group(
242
+ inputs=[
243
+ _transforms.RepackTransform(
244
+ {
245
+ "images": {"cam_high": "observation.images.top"},
246
+ "state": "observation.state",
247
+ "actions": "action",
248
+ }
249
+ )
250
+ ]
251
+ )
252
+ )
253
+ # Action keys that will be used to read the action sequence from the dataset.
254
+ action_sequence_keys: Sequence[str] = ("action",)
255
+
256
+ @override
257
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
258
+ data_transforms = _transforms.Group(
259
+ inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)],
260
+ outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)],
261
+ )
262
+ if self.use_delta_joint_actions:
263
+ delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1)
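+ # make_bool_mask(6, -1, 6, -1) yields six True entries, one False, six True, one False: the six joint
+ # dimensions of each arm are converted to deltas, while each gripper dimension stays absolute.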
264
+ data_transforms = data_transforms.push(
265
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
266
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
267
+ )
268
+
269
+ model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)
270
+
271
+ return dataclasses.replace(
272
+ self.create_base_config(assets_dirs, model_config),
273
+ repack_transforms=self.repack_transforms,
274
+ data_transforms=data_transforms,
275
+ model_transforms=model_transforms,
276
+ action_sequence_keys=self.action_sequence_keys,
277
+ )
278
+
279
+
280
+ @dataclasses.dataclass(frozen=True)
281
+ class LeRobotLiberoDataConfig(DataConfigFactory):
282
+ """
283
+ This config is used to configure transforms that are applied at various parts of the data pipeline.
284
+ For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
285
+ comments below.
286
+ """
287
+
288
+ extra_delta_transform: bool = False
289
+
290
+ @override
291
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
292
+ # The repack transform is *only* applied to the data coming from the dataset,
293
+ # and *not* during inference. We can use it to make inputs from the dataset look
294
+ # as close as possible to those coming from the inference environment (e.g. match the keys).
295
+ # Below, we match the keys in the dataset (which we defined in the data conversion script) to
296
+ # the keys we use in our inference pipeline (defined in the inference script for libero).
297
+ # For your own dataset, first figure out what keys your environment passes to the policy server
298
+ # and then modify the mappings below so your dataset's keys get matched to those target keys.
299
+ # The repack transform simply remaps key names here.
300
+ repack_transform = _transforms.Group(
301
+ inputs=[
302
+ _transforms.RepackTransform(
303
+ {
304
+ "observation/image": "image",
305
+ "observation/wrist_image": "wrist_image",
306
+ "observation/state": "state",
307
+ "actions": "actions",
308
+ "prompt": "prompt",
309
+ }
310
+ )
311
+ ]
312
+ )
313
+
314
+ # The data transforms are applied to the data coming from the dataset *and* during inference.
315
+ # Below, we define the transforms for data going into the model (``inputs``) and the transforms
316
+ # for data coming out of the model (``outputs``) (the latter is only used during inference).
317
+ # We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
318
+ # how to modify the transforms to match your dataset. Once you have created your own transforms, you can
319
+ # replace the transforms below with your own.
320
+ data_transforms = _transforms.Group(
321
+ inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],
322
+ outputs=[libero_policy.LiberoOutputs()],
323
+ )
324
+
325
+ # One additional data transform: pi0 models are trained on delta actions (relative to the first
326
+ # state in each action chunk). If your data has ``absolute`` actions (e.g. target joint angles),
327
+ # you can enable the delta conversion below to convert the actions to delta actions. The only exception
328
+ # is for the gripper actions which are always absolute.
329
+ # In the example below, we would apply the delta conversion to the first 6 actions (joints) and
330
+ # leave the 7th action (gripper) unchanged, i.e. absolute.
331
+ # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
332
+ # apply a separate delta conversion (that's why it is disabled by default). Choose whether to apply this
333
+ # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
334
+
335
+ # LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this
336
+ # extra delta transform.
337
+ if self.extra_delta_transform:
338
+ delta_action_mask = _transforms.make_bool_mask(6, -1)
339
+ data_transforms = data_transforms.push(
340
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
341
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
342
+ )
343
+
344
+ # Model transforms include things like tokenizing the prompt and action targets
345
+ # You do not need to change anything here for your own dataset.
346
+ model_transforms = ModelTransformFactory()(model_config)
347
+
348
+ # We return all data transforms for training and inference. No need to change anything here.
349
+ return dataclasses.replace(
350
+ self.create_base_config(assets_dirs, model_config),
351
+ repack_transforms=repack_transform,
352
+ data_transforms=data_transforms,
353
+ model_transforms=model_transforms,
354
+ )
355
+
356
+
357
+ @dataclasses.dataclass(frozen=True)
358
+ class RLDSDroidDataConfig(DataConfigFactory):
359
+ """
360
+ Config for training on DROID, using RLDS data format (for efficient training on larger datasets).
361
+ """
362
+
363
+ rlds_data_dir: str | None = None
364
+ action_space: droid_rlds_dataset.DroidActionSpace | None = None
365
+
366
+ # Filtering options. Can pass a path to a dictionary that maps episodes to tuples denoting
367
+ # ranges of time steps to keep (start, end). Episodes are uniquely identified with
368
+ # f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata.
369
+ # Path to the filter dictionary file.
370
+ filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json"
371
+
372
+ @override
373
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
374
+ repack_transform = _transforms.Group(
375
+ inputs=[
376
+ _transforms.RepackTransform(
377
+ {
378
+ "observation/exterior_image_1_left": "observation/image",
379
+ "observation/wrist_image_left": "observation/wrist_image",
380
+ "observation/joint_position": "observation/joint_position",
381
+ "observation/gripper_position": "observation/gripper_position",
382
+ "actions": "actions",
383
+ "prompt": "prompt",
384
+ }
385
+ )
386
+ ]
387
+ )
388
+
389
+ data_transforms = _transforms.Group(
390
+ inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
391
+ outputs=[droid_policy.DroidOutputs()],
392
+ )
393
+
394
+ if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION:
395
+ # Data loader returns absolute joint position actions -- convert to delta actions for training.
396
+ delta_action_mask = _transforms.make_bool_mask(7, -1)
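+ # make_bool_mask(7, -1) marks the 7 joint dimensions for delta conversion and keeps the gripper absolute.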
397
+ data_transforms = data_transforms.push(
398
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
399
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
400
+ )
401
+
402
+ model_transforms = ModelTransformFactory()(model_config)
403
+
404
+ assert self.rlds_data_dir is not None, "Need to set rlds data dir for RLDS data loader."
405
+
406
+ return dataclasses.replace(
407
+ self.create_base_config(assets_dirs, model_config),
408
+ repack_transforms=repack_transform,
409
+ data_transforms=data_transforms,
410
+ model_transforms=model_transforms,
411
+ rlds_data_dir=self.rlds_data_dir,
412
+ action_space=self.action_space,
413
+ filter_dict_path=self.filter_dict_path,
414
+ )
415
+
416
+
417
+ @dataclasses.dataclass(frozen=True)
418
+ class LeRobotDROIDDataConfig(DataConfigFactory):
419
+ """
420
+ Example data config for custom DROID dataset in LeRobot format.
421
+ To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py
422
+ """
423
+
424
+ @override
425
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
426
+ repack_transform = _transforms.Group(
427
+ inputs=[
428
+ _transforms.RepackTransform(
429
+ {
430
+ "observation/exterior_image_1_left": "exterior_image_1_left",
431
+ "observation/exterior_image_2_left": "exterior_image_2_left",
432
+ "observation/wrist_image_left": "wrist_image_left",
433
+ "observation/joint_position": "joint_position",
434
+ "observation/gripper_position": "gripper_position",
435
+ "actions": "actions",
436
+ "prompt": "prompt",
437
+ }
438
+ )
439
+ ]
440
+ )
441
+ # We assume joint *velocity* actions, so we should *not* apply an additional delta transform.
442
+ data_transforms = _transforms.Group(
443
+ inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
444
+ outputs=[droid_policy.DroidOutputs()],
445
+ )
446
+ model_transforms = ModelTransformFactory()(model_config)
447
+
448
+ return dataclasses.replace(
449
+ self.create_base_config(assets_dirs, model_config),
450
+ repack_transforms=repack_transform,
451
+ data_transforms=data_transforms,
452
+ model_transforms=model_transforms,
453
+ )
454
+
455
+
456
+ @dataclasses.dataclass(frozen=True)
457
+ class TrainConfig:
458
+ # Name of the config. Must be unique. Will be used to reference this config.
459
+ name: tyro.conf.Suppress[str]
460
+ # Project name.
461
+ project_name: str = "openpi"
462
+ # Experiment name. Will be used to name the metadata and checkpoint directories.
463
+ exp_name: str = tyro.MISSING
464
+
465
+ # Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models
466
+ # -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may
467
+ # define additional attributes.
468
+ model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config)
469
+
470
+ # A weight loader can optionally load (possibly partial) weights from disk after the model is initialized.
471
+ weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader)
472
+
473
+ # Optional path to a PyTorch checkpoint to load weights from.
474
+ pytorch_weight_path: str | None = None
475
+
476
+ # Spatial Forcing configs
477
+ vggt_weight_path: str | None = None
478
+ vggt_dim: int = 1024
479
+
480
+ vla_layers_align: int | None = None # total 18 for paligemma-2b
481
+ vggt_layers_align: int | None = None # total 24 for VGGT
482
+
483
+ pooling_func: str | None = None
484
+ use_vggt_pe: bool | None = None
485
+ use_vlm_norm: bool | None = None
486
+
487
+ align_loss_coeff: float = 0.0
488
+
489
+ # CapVector configs
490
+ regularization_vector_path: str | None = None
491
+ regularization_coeff: float = 0.0
492
+
493
+ # Precision for PyTorch training.
494
+ pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16"
495
+
496
+ lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule)
497
+ optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW)
498
+ ema_decay: float | None = 0.99
499
+
500
+ # Specifies which weights should be frozen.
501
+ freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing)
502
+
503
+ # Determines the data to be trained on.
504
+ data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig)
505
+
506
+ # Base directory for config assets (e.g., norm stats).
507
+ assets_base_dir: str = "./assets"
508
+ # Base directory for checkpoints.
509
+ checkpoint_base_dir: str = "./checkpoints"
510
+
511
+ # Random seed that will be used by random generators during training.
512
+ seed: int = 42
513
+ # Global batch size.
514
+ batch_size: int = 32
515
+ # Number of workers to use for the data loader. Increasing this number will speed up data loading but
516
+ # will increase memory and CPU usage.
517
+ num_workers: int = 2
518
+ # Number of train steps (batches) to run.
519
+ num_train_steps: int = 30_000
520
+
521
+ # How often (in steps) to log training metrics.
522
+ log_interval: int = 100
523
+ # How often (in steps) to save checkpoints.
524
+ save_interval: int = 1000
525
+ # If set, any existing checkpoints matching step % keep_period == 0 will not be deleted.
526
+ keep_period: int | None = 5000
527
+
528
+ # If true, will overwrite the checkpoint directory if it already exists.
529
+ overwrite: bool = False
530
+ # If true, will resume training from the last checkpoint.
531
+ resume: bool = False
532
+
533
+ # If true, will enable wandb logging.
534
+ wandb_enabled: bool = True
535
+
536
+ # Used to pass metadata to the policy server.
537
+ policy_metadata: dict[str, Any] | None = None
538
+
539
+ # If the value is greater than 1, FSDP will be enabled and shard across the specified number of devices; overall
540
+ # device memory will be reduced but training could potentially be slower.
541
+ # e.g. if there are 4 devices in total and fsdp_devices is 2, then the model will be sharded across 2 devices and run
542
+ # data parallelism between the 2 groups of devices.
543
+ fsdp_devices: int = 1
544
+
545
+ @property
546
+ def assets_dirs(self) -> pathlib.Path:
547
+ """Get the assets directory for this config."""
548
+ return (pathlib.Path(self.assets_base_dir) / self.name).resolve()
549
+
550
+ @property
551
+ def checkpoint_dir(self) -> pathlib.Path:
552
+ """Get the checkpoint directory for this config."""
553
+ if not self.exp_name:
554
+ raise ValueError("--exp_name must be set")
555
+ return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve()
556
+
557
+ @property
558
+ def trainable_filter(self) -> nnx.filterlib.Filter:
559
+ """Get the filter for the trainable parameters."""
560
+ return nnx.All(nnx.Param, nnx.Not(self.freeze_filter))
561
+
562
+ def __post_init__(self) -> None:
563
+ if self.resume and self.overwrite:
564
+ raise ValueError("Cannot resume and overwrite at the same time.")
565
+
566
+
567
+ # Use `get_config` if you need to get a config by name in your code.
568
+ _CONFIGS = [
569
+ #
570
+ # Inference Aloha configs.
571
+ #
572
+ TrainConfig(
573
+ name="pi0_aloha",
574
+ model=pi0_config.Pi0Config(),
575
+ data=LeRobotAlohaDataConfig(
576
+ assets=AssetsConfig(asset_id="trossen"),
577
+ ),
578
+ policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
579
+ ),
580
+ TrainConfig(
581
+ name="pi05_aloha",
582
+ model=pi0_config.Pi0Config(pi05=True),
583
+ data=LeRobotAlohaDataConfig(
584
+ assets=AssetsConfig(asset_id="trossen"),
585
+ ),
586
+ policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
587
+ ),
588
+ TrainConfig(
589
+ name="pi0_aloha_towel",
590
+ model=pi0_config.Pi0Config(),
591
+ data=LeRobotAlohaDataConfig(
592
+ assets=AssetsConfig(asset_id="trossen"),
593
+ default_prompt="fold the towel",
594
+ ),
595
+ policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
596
+ ),
597
+ TrainConfig(
598
+ name="pi0_aloha_tupperware",
599
+ model=pi0_config.Pi0Config(),
600
+ data=LeRobotAlohaDataConfig(
601
+ assets=AssetsConfig(asset_id="trossen"),
602
+ default_prompt="open the tupperware and put the food on the plate",
603
+ ),
604
+ policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
605
+ ),
606
+ #
607
+ # Inference DROID configs.
608
+ #
609
+ TrainConfig(
610
+ name="pi0_droid",
611
+ model=pi0_config.Pi0Config(action_horizon=10),
612
+ data=SimpleDataConfig(
613
+ assets=AssetsConfig(asset_id="droid"),
614
+ data_transforms=lambda model: _transforms.Group(
615
+ inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0)],
616
+ outputs=[droid_policy.DroidOutputs()],
617
+ ),
618
+ base_config=DataConfig(
619
+ prompt_from_task=True,
620
+ ),
621
+ ),
622
+ ),
623
+ TrainConfig(
624
+ name="pi0_fast_droid",
625
+ model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
626
+ data=SimpleDataConfig(
627
+ assets=AssetsConfig(asset_id="droid"),
628
+ data_transforms=lambda model: _transforms.Group(
629
+ inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)],
630
+ outputs=[droid_policy.DroidOutputs()],
631
+ ),
632
+ base_config=DataConfig(
633
+ prompt_from_task=True,
634
+ ),
635
+ ),
636
+ ),
637
+ TrainConfig(
638
+ name="pi05_droid",
639
+ model=pi0_config.Pi0Config(action_horizon=15, pi05=True),
640
+ data=SimpleDataConfig(
641
+ assets=AssetsConfig(asset_id="droid"),
642
+ data_transforms=lambda model: _transforms.Group(
643
+ inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)],
644
+ outputs=[droid_policy.DroidOutputs()],
645
+ ),
646
+ base_config=DataConfig(
647
+ prompt_from_task=True,
648
+ ),
649
+ ),
650
+ ),
651
+ #
652
+ # Fine-tuning Libero configs.
653
+ #
654
+ # These train configs define the hyperparameters for fine-tuning the base model on your own dataset.
655
+ # They are used to define key elements like the dataset you are training on, the base checkpoint you
656
+ # are using, and other hyperparameters like how many training steps to run or what learning rate to use.
657
+ # For your own dataset, you can copy this class and modify the dataset name, and data transforms based on
658
+ # the comments below.
659
+ TrainConfig(
660
+ # Change the name to reflect your model and dataset.
661
+ name="pi0_libero",
662
+ # Here you define the model config -- In this example we use pi0 as the model
663
+ # architecture and perform *full* finetuning. In the examples below we show how to modify
664
+ # this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture.
665
+ model=pi0_config.Pi0Config(),
666
+ # Here you define the dataset you are training on. In this example we use the Libero
667
+ # dataset. For your own dataset, you can change the repo_id to point to your dataset.
668
+ # Also modify the DataConfig to use the new config you made for your dataset above.
669
+ data=LeRobotLiberoDataConfig(
670
+ repo_id="physical-intelligence/libero",
671
+ base_config=DataConfig(
672
+ # This flag determines whether we load the prompt (i.e. the task instruction) from the
673
+ # ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
674
+ # a field called ``prompt`` in the input dict. The recommended setting is True.
675
+ prompt_from_task=True,
676
+ ),
677
+ extra_delta_transform=True,
678
+ ),
679
+ # Here you define which pre-trained checkpoint you want to load to initialize the model.
680
+ # This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
681
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
682
+ # Below you can define other hyperparameters like the learning rate, number of training steps, etc.
683
+ # Check the base TrainConfig class for a full list of available hyperparameters.
684
+ num_train_steps=30_000,
685
+ ),
686
+ TrainConfig(
687
+ name="pi0_libero_low_mem_finetune",
688
+ # Here is an example of loading a pi0 model for LoRA fine-tuning.
689
+ model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
690
+ data=LeRobotLiberoDataConfig(
691
+ repo_id="physical-intelligence/libero",
692
+ base_config=DataConfig(prompt_from_task=True),
693
+ extra_delta_transform=True,
694
+ ),
695
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
696
+ num_train_steps=30_000,
697
+ # The freeze filter defines which parameters should be frozen during training.
698
+ # We have a convenience function in the model config that returns the default freeze filter
699
+ # for the given model config for LoRA finetuning. Just make sure it matches the model config
700
+ # you chose above.
701
+ freeze_filter=pi0_config.Pi0Config(
702
+ paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
703
+ ).get_freeze_filter(),
704
+ # Turn off EMA for LoRA finetuning.
705
+ ema_decay=None,
706
+ ),
707
+ TrainConfig(
708
+ name="pi0_fast_libero",
709
+ # Here is an example of loading a pi0-FAST model for full finetuning.
710
+ # Modify action_dim and action_horizon to match your dataset (action horizon is equal to
711
+ # the desired action chunk length).
712
+ # The max_token_len is the maximum number of (non-image) tokens the model can handle.
713
+ # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.
714
+ # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw
715
+ # a warning), while choosing it too large will waste memory (since we pad each batch element to the
716
+ # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for
717
+ # two-arm robots. Generally, err on the lower side here first, and potentially increase the value if
718
+ # you see many warnings being thrown during training.
719
+ model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),
720
+ data=LeRobotLiberoDataConfig(
721
+ repo_id="physical-intelligence/libero",
722
+ base_config=DataConfig(prompt_from_task=True),
723
+ extra_delta_transform=True,
724
+ ),
725
+ # Note that we load the pi0-FAST base model checkpoint here.
726
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
727
+ num_train_steps=30_000,
728
+ ),
729
+ TrainConfig(
730
+ name="pi0_fast_libero_low_mem_finetune",
731
+ # Here is an example of loading a pi0-FAST model for LoRA finetuning.
732
+ # For setting action_dim, action_horizon, and max_token_len, see the comments above.
733
+ model=pi0_fast.Pi0FASTConfig(
734
+ action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
735
+ ),
736
+ data=LeRobotLiberoDataConfig(
737
+ repo_id="physical-intelligence/libero",
738
+ base_config=DataConfig(prompt_from_task=True),
739
+ extra_delta_transform=True,
740
+ ),
741
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
742
+ num_train_steps=30_000,
743
+ # Again, make sure to match the model config above when extracting the freeze filter
744
+ # that specifies which parameters should be frozen during LoRA finetuning.
745
+ freeze_filter=pi0_fast.Pi0FASTConfig(
746
+ action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
747
+ ).get_freeze_filter(),
748
+ # Turn off EMA for LoRA finetuning.
749
+ ema_decay=None,
750
+ ),
751
+ TrainConfig(
752
+ name="pi05_libero",
753
+ model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False),
754
+ data=LeRobotLiberoDataConfig(
755
+ repo_id="physical-intelligence/libero",
756
+ base_config=DataConfig(prompt_from_task=True),
757
+ extra_delta_transform=False,
758
+ ),
759
+ batch_size=256,
760
+ lr_schedule=_optimizer.CosineDecaySchedule(
761
+ warmup_steps=10_000,
762
+ peak_lr=5e-5,
763
+ decay_steps=1_000_000,
764
+ decay_lr=5e-5,
765
+ ),
766
+ optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
767
+ ema_decay=0.999,
768
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
769
+ pytorch_weight_path="/path/to/your/pytorch_weight_path",
770
+ num_train_steps=30_000,
771
+ ),
772
+ #
773
+ # Fine-tuning Aloha configs.
774
+ #
775
+ # Personal Tasks
776
+ TrainConfig(
777
+ name="pi05_capvector_aloha_place_block", # <config_name>
778
+ model=pi0_config.Pi0Config(pi05=True, discrete_state_input=False),
779
+ data=LeRobotAlohaDataConfig(
780
+ repo_id="cobot_dataset/place_one_floor_block", # your datasets repo_id, like "<org>/<dataset-name>"
781
+ default_prompt="place the green block",
782
+ repack_transforms=_transforms.Group(
783
+ inputs=[
784
+ _transforms.RepackTransform(
785
+ {
786
+ "images": {
787
+ "cam_high": "observation.images.cam_high",
788
+ "cam_left_wrist": "observation.images.cam_left_wrist",
789
+ "cam_right_wrist": "observation.images.cam_right_wrist",
790
+ },
791
+ "state": "observation.state",
792
+ "actions": "action",
793
+ }
794
+ )
795
+ ]
796
+ ),
797
+ ),
798
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
799
+ pytorch_weight_path="./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial",
800
+ # CapVector
801
+ regularization_vector_path="checkpoints/diff/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial.pth",
802
+ regularization_coeff=1e-4,
803
+ #
804
+ num_train_steps=30_000,
805
+ batch_size=16,
806
+ ema_decay=None,
807
+ wandb_enabled=False,
808
+ ),
809
+ #
810
+ # This is a test config that is used to illustrate how to train on a custom LeRobot dataset.
811
+ # For instructions on how to convert and train on your own Aloha dataset, see examples/aloha_real/README.md
812
+ TrainConfig(
813
+ name="pi0_aloha_pen_uncap",
814
+ model=pi0_config.Pi0Config(),
815
+ data=LeRobotAlohaDataConfig(
816
+ repo_id="physical-intelligence/aloha_pen_uncap_diverse",
817
+ assets=AssetsConfig(
818
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
819
+ asset_id="trossen",
820
+ ),
821
+ default_prompt="uncap the pen",
822
+ repack_transforms=_transforms.Group(
823
+ inputs=[
824
+ _transforms.RepackTransform(
825
+ {
826
+ "images": {
827
+ "cam_high": "observation.images.cam_high",
828
+ "cam_left_wrist": "observation.images.cam_left_wrist",
829
+ "cam_right_wrist": "observation.images.cam_right_wrist",
830
+ },
831
+ "state": "observation.state",
832
+ "actions": "action",
833
+ }
834
+ )
835
+ ]
836
+ ),
837
+ ),
838
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
839
+ num_train_steps=20_000,
840
+ ),
841
+ TrainConfig(
842
+ name="pi05_aloha_pen_uncap",
843
+ model=pi0_config.Pi0Config(pi05=True),
844
+ data=LeRobotAlohaDataConfig(
845
+ repo_id="physical-intelligence/aloha_pen_uncap_diverse",
846
+ assets=AssetsConfig(
847
+ assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
848
+ asset_id="trossen",
849
+ ),
850
+ default_prompt="uncap the pen",
851
+ repack_transforms=_transforms.Group(
852
+ inputs=[
853
+ _transforms.RepackTransform(
854
+ {
855
+ "images": {
856
+ "cam_high": "observation.images.cam_high",
857
+ "cam_left_wrist": "observation.images.cam_left_wrist",
858
+ "cam_right_wrist": "observation.images.cam_right_wrist",
859
+ },
860
+ "state": "observation.state",
861
+ "actions": "action",
862
+ }
863
+ )
864
+ ]
865
+ ),
866
+ ),
867
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
868
+ num_train_steps=20_000,
869
+ batch_size=64,
870
+ ),
871
+ #
872
+ # Fine-tuning DROID configs.
873
+ #
874
+ TrainConfig(
875
+ # This config is for fine-tuning pi0-FAST-base on the *full* DROID dataset.
876
+ # We use RLDS data loading to make training on this large dataset tractable.
877
+ # For fine-tuning on your own DROID dataset, see below.
878
+ name="pi0_fast_full_droid_finetune",
879
+ model=pi0_fast.Pi0FASTConfig(
880
+ action_dim=8,
881
+ action_horizon=16,
882
+ max_token_len=180,
883
+ ),
884
+ data=RLDSDroidDataConfig(
885
+ repo_id="droid",
886
+ # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
887
+ rlds_data_dir="<path_to_droid_rlds_dataset>",
888
+ action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
889
+ ),
890
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
891
+ lr_schedule=_optimizer.CosineDecaySchedule(
892
+ warmup_steps=1_000,
893
+ peak_lr=5e-5,
894
+ decay_steps=1_000_000,
895
+ decay_lr=5e-5,
896
+ ),
897
+ num_train_steps=100_000, # 100k steps should be sufficient, takes ~2 days on 8x H100s
898
+ batch_size=256,
899
+ log_interval=100,
900
+ save_interval=5000,
901
+ keep_period=20_000,
902
+ num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
903
+ ),
904
+ TrainConfig(
905
+ # This config is for fine-tuning pi05 on the *full* DROID dataset.
906
+ # We use RLDS data loading to make training on this large dataset tractable.
907
+ # For fine-tuning on your own DROID dataset, see below.
908
+ name="pi05_full_droid_finetune",
909
+ model=pi0_config.Pi0Config(
910
+ pi05=True,
911
+ action_dim=32,
912
+ action_horizon=16,
913
+ ),
914
+ data=RLDSDroidDataConfig(
915
+ repo_id="droid",
916
+ # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
917
+ rlds_data_dir="/mnt/pi-data/kevin",
918
+ action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
919
+ assets=AssetsConfig(
920
+ assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/",
921
+ asset_id="droid",
922
+ ),
923
+ ),
924
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
925
+ lr_schedule=_optimizer.CosineDecaySchedule(
926
+ warmup_steps=1_000,
927
+ peak_lr=5e-5,
928
+ decay_steps=1_000_000,
929
+ decay_lr=5e-5,
930
+ ),
931
+ num_train_steps=100_000,
932
+ batch_size=256,
933
+ log_interval=100,
934
+ save_interval=5000,
935
+ keep_period=10_000,
936
+ num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
937
+ ),
938
+ TrainConfig(
939
+ # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset.
940
+ # Here, we use LeRobot data format (like for all other fine-tuning examples)
941
+ # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py
942
+ name="pi05_droid_finetune",
943
+ model=pi0_config.Pi0Config(
944
+ pi05=True,
945
+ action_dim=32, # pi05 is trained with 32-dim actions
946
+ action_horizon=16,
947
+ ),
948
+ data=LeRobotDROIDDataConfig(
949
+ # Replace with your custom DROID LeRobot dataset repo id.
950
+ repo_id="your_hf_username/my_droid_dataset",
951
+ base_config=DataConfig(prompt_from_task=True),
952
+ assets=AssetsConfig(
953
+ # Important: reuse the original DROID norm stats during fine-tuning!
954
+ assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets",
955
+ asset_id="droid",
956
+ ),
957
+ ),
958
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_droid/params"),
959
+ num_train_steps=20_000,
960
+ batch_size=32,
961
+ ),
962
+ #
963
+ # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
964
+ #
965
+ TrainConfig(
966
+ name="pi0_aloha_sim",
967
+ model=pi0_config.Pi0Config(),
968
+ data=LeRobotAlohaDataConfig(
969
+ repo_id="lerobot/aloha_sim_transfer_cube_human",
970
+ default_prompt="Transfer cube",
971
+ use_delta_joint_actions=False,
972
+ ),
973
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
974
+ num_train_steps=20_000,
975
+ ),
976
+ #
977
+ # Debugging configs.
978
+ #
979
+ TrainConfig(
980
+ name="debug",
981
+ data=FakeDataConfig(),
982
+ batch_size=2,
983
+ model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"),
984
+ save_interval=100,
985
+ overwrite=True,
986
+ exp_name="debug",
987
+ num_train_steps=10,
988
+ wandb_enabled=False,
989
+ ),
990
+ TrainConfig(
991
+ name="debug_restore",
992
+ data=FakeDataConfig(),
993
+ batch_size=2,
994
+ model=pi0_config.Pi0Config(paligemma_variant="dummy", action_expert_variant="dummy"),
995
+ weight_loader=weight_loaders.CheckpointWeightLoader("./checkpoints/debug/debug/9/params"),
996
+ overwrite=True,
997
+ exp_name="debug",
998
+ num_train_steps=10,
999
+ wandb_enabled=False,
1000
+ ),
1001
+ TrainConfig(
1002
+ name="debug_pi05",
1003
+ model=pi0_config.Pi0Config(pi05=True, paligemma_variant="dummy", action_expert_variant="dummy"),
1004
+ data=FakeDataConfig(),
1005
+ batch_size=2,
1006
+ num_train_steps=10,
1007
+ overwrite=True,
1008
+ exp_name="debug_pi05",
1009
+ wandb_enabled=False,
1010
+ ),
1011
+ #
1012
+ # RoboArena configs.
1013
+ #
1014
+ *roboarena_config.get_roboarena_configs(),
1015
+ ]
1016
+
1017
+ if len({config.name for config in _CONFIGS}) != len(_CONFIGS):
1018
+ raise ValueError("Config names must be unique.")
1019
+ _CONFIGS_DICT = {config.name: config for config in _CONFIGS}
1020
+
1021
+
1022
+ def cli() -> TrainConfig:
1023
+ return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()})
1024
+
1025
+
1026
+ def get_config(config_name: str) -> TrainConfig:
1027
+ """Get a config by name."""
1028
+ if config_name not in _CONFIGS_DICT:
1029
+ closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0)
1030
+ closest_str = f" Did you mean '{closest[0]}'? " if closest else ""
1031
+ raise ValueError(f"Config '{config_name}' not found.{closest_str}")
1032
+
1033
+ return _CONFIGS_DICT[config_name]
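For reference, a minimal sketch of how these configs are typically consumed programmatically (the config name and override values here are illustrative, not part of this file):

    import dataclasses

    from openpi.training import config as _config

    # Look up a registered config by name and override a few fields for a quick experiment.
    train_config = _config.get_config("pi05_libero")
    train_config = dataclasses.replace(train_config, exp_name="my_experiment", batch_size=32)

    # checkpoint_dir resolves to <checkpoint_base_dir>/<config name>/<exp_name>.
    print(train_config.checkpoint_dir)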
capvector-pi05/src/openpi/training/data_loader.py ADDED
@@ -0,0 +1,540 @@
1
+ from collections.abc import Iterator, Sequence
2
+ import logging
3
+ import multiprocessing
4
+ import os
5
+ import typing
6
+ from typing import Literal, Protocol, SupportsIndex, TypeVar
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
11
+ import numpy as np
12
+ import torch
13
+
14
+ import openpi.models.model as _model
15
+ import openpi.training.config as _config
16
+ from openpi.training.droid_rlds_dataset import DroidRldsDataset
17
+ import openpi.transforms as _transforms
18
+
19
+ T_co = TypeVar("T_co", covariant=True)
20
+
21
+
22
+ class Dataset(Protocol[T_co]):
23
+ """Interface for a dataset with random access."""
24
+
25
+ def __getitem__(self, index: SupportsIndex) -> T_co:
26
+ raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
27
+
28
+ def __len__(self) -> int:
29
+ raise NotImplementedError("Subclasses of Dataset should implement __len__.")
30
+
31
+
32
+ class IterableDataset(Protocol[T_co]):
33
+ """Interface for an iterable dataset."""
34
+
35
+ def __iter__(self) -> Iterator[T_co]:
36
+ raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")
37
+
38
+ def __len__(self) -> int:
39
+ raise NotImplementedError("Subclasses of Dataset should implement __len__.")
40
+
41
+
42
+ class DataLoader(Protocol[T_co]):
43
+ """Interface for a data loader."""
44
+
45
+ def data_config(self) -> _config.DataConfig:
46
+ """Get the data config for this data loader."""
47
+ raise NotImplementedError("Subclasses of DataLoader should implement data_config.")
48
+
49
+ def __iter__(self) -> Iterator[T_co]:
50
+ raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")
51
+
52
+
53
+ class TransformedDataset(Dataset[T_co]):
54
+ def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
55
+ self._dataset = dataset
56
+ self._transform = _transforms.compose(transforms)
57
+
58
+ def __getitem__(self, index: SupportsIndex) -> T_co:
59
+ return self._transform(self._dataset[index])
60
+
61
+ def __len__(self) -> int:
62
+ return len(self._dataset)
63
+
64
+
65
+ class IterableTransformedDataset(IterableDataset[T_co]):
66
+ def __init__(
67
+ self,
68
+ dataset: IterableDataset,
69
+ transforms: Sequence[_transforms.DataTransformFn],
70
+ *,
71
+ is_batched: bool = False,
72
+ ):
73
+ self._dataset = dataset
74
+ self._transform = _transforms.compose(transforms)
75
+ self._is_batched = is_batched
76
+
77
+ def __iter__(self):
78
+ for sample in self._dataset:
79
+ if self._is_batched:
80
+ # Transforms are designed to be applied to individual samples. So we need to split the batch into
81
+ # individual samples and apply the transform to each sample individually.
82
+ batch_size = next(v.shape[0] for v in sample.values())
83
+
84
+ # Split batch into individual samples using tree_map
85
+ individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)] # noqa: B023
86
+
87
+ # Transform each sample
88
+ transformed = [self._transform(s) for s in individual_samples]
89
+
90
+ # Recombine batch with tree_map
91
+ yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)
92
+ else:
93
+ yield self._transform(sample)
94
+
95
+ def __len__(self) -> int:
96
+ return len(self._dataset)
97
+
98
+
99
+ class FakeDataset(Dataset):
100
+ def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
101
+ self._num_samples = num_samples
102
+ self._observation_spec, self._action_spec = model_config.inputs_spec()
103
+
104
+ def __getitem__(self, index: SupportsIndex) -> dict:
105
+ rng = jax.random.key(index.__index__())
106
+
107
+ def make_from_spec(spec: jax.ShapeDtypeStruct):
108
+ nonlocal rng
109
+ rng, data_rng = jax.random.split(rng)
110
+ # Remove the batch dimension.
111
+ shape = spec.shape[1:]
112
+ if spec.dtype == jnp.float32:
113
+ return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0)
114
+ if spec.dtype == jnp.int32:
115
+ return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048)
116
+ return jnp.zeros(shape=shape, dtype=spec.dtype)
117
+
118
+ observation = jax.tree.map(make_from_spec, self._observation_spec)
119
+ action = jax.tree.map(make_from_spec, self._action_spec)
120
+
121
+ return {
122
+ **observation.to_dict(),
123
+ "actions": action,
124
+ }
125
+
126
+ def __len__(self) -> int:
127
+ return self._num_samples
128
+
129
+
130
+ def create_torch_dataset(
131
+ data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
132
+ ) -> Dataset:
133
+ """Create a dataset for training."""
134
+ repo_id = data_config.repo_id
135
+ if repo_id is None:
136
+ raise ValueError("Repo ID is not set. Cannot create dataset.")
137
+ if repo_id == "fake":
138
+ return FakeDataset(model_config, num_samples=1024)
139
+
140
+ dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
141
+ dataset = lerobot_dataset.LeRobotDataset(
142
+ data_config.repo_id,
143
+ delta_timestamps={
144
+ key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
145
+ },
146
+ )
147
+
148
+ if data_config.prompt_from_task:
149
+ dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])
150
+
151
+ return dataset
152
+
153
+
154
+ def create_rlds_dataset(
155
+ data_config: _config.DataConfig,
156
+ action_horizon: int,
157
+ batch_size: int,
158
+ *,
159
+ shuffle: bool = False,
160
+ ) -> Dataset:
161
+ # At the moment, we only support DROID for RLDS datasets.
162
+ return DroidRldsDataset(
163
+ data_dir=data_config.rlds_data_dir,
164
+ batch_size=batch_size,
165
+ shuffle=shuffle,
166
+ action_chunk_size=action_horizon,
167
+ action_space=data_config.action_space,
168
+ filter_dict_path=data_config.filter_dict_path,
169
+ )
170
+
171
+
172
+ def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
173
+ """Transform the dataset by applying the data transforms."""
174
+ norm_stats = {}
175
+ if data_config.repo_id != "fake" and not skip_norm_stats:
176
+ if data_config.norm_stats is None:
177
+ raise ValueError(
178
+ "Normalization stats not found. "
179
+ "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
180
+ )
181
+ norm_stats = data_config.norm_stats
182
+
183
+ return TransformedDataset(
184
+ dataset,
185
+ [
186
+ *data_config.repack_transforms.inputs,
187
+ *data_config.data_transforms.inputs,
188
+ _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
189
+ *data_config.model_transforms.inputs,
190
+ ],
191
+ )
192
+
193
+
194
+ def transform_iterable_dataset(
195
+ dataset: IterableDataset,
196
+ data_config: _config.DataConfig,
197
+ *,
198
+ skip_norm_stats: bool = False,
199
+ is_batched: bool = False,
200
+ ) -> IterableDataset:
201
+ """Transform the dataset by applying the data transforms."""
202
+ norm_stats = {}
203
+ if data_config.repo_id != "fake" and not skip_norm_stats:
204
+ if data_config.norm_stats is None:
205
+ raise ValueError(
206
+ "Normalization stats not found. "
207
+ "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
208
+ )
209
+ norm_stats = data_config.norm_stats
210
+
211
+ return IterableTransformedDataset(
212
+ dataset,
213
+ [
214
+ *data_config.repack_transforms.inputs,
215
+ *data_config.data_transforms.inputs,
216
+ _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
217
+ *data_config.model_transforms.inputs,
218
+ ],
219
+ is_batched=is_batched,
220
+ )
221
+
222
+
223
+ def create_data_loader(
224
+ config: _config.TrainConfig,
225
+ *,
226
+ sharding: jax.sharding.Sharding | None = None,
227
+ shuffle: bool = False,
228
+ num_batches: int | None = None,
229
+ skip_norm_stats: bool = False,
230
+ framework: Literal["jax", "pytorch"] = "jax",
231
+ ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
232
+ """Create a data loader for training.
233
+
234
+ Args:
235
+ config: The training configuration.
236
+ sharding: The sharding to use for the data loader (JAX only).
237
+ shuffle: Whether to shuffle the data.
238
+ num_batches: Determines the number of batches to return.
239
+ skip_norm_stats: Whether to skip data normalization.
240
+ framework: The framework to use ("jax" or "pytorch").
241
+ """
242
+ data_config = config.data.create(config.assets_dirs, config.model)
243
+ logging.info(f"data_config: {data_config}")
244
+
245
+ if data_config.rlds_data_dir is not None:
246
+ return create_rlds_data_loader(
247
+ data_config,
248
+ action_horizon=config.model.action_horizon,
249
+ batch_size=config.batch_size,
250
+ sharding=sharding,
251
+ shuffle=shuffle,
252
+ num_batches=num_batches,
253
+ skip_norm_stats=skip_norm_stats,
254
+ framework=framework,
255
+ )
256
+ return create_torch_data_loader(
257
+ data_config,
258
+ model_config=config.model,
259
+ action_horizon=config.model.action_horizon,
260
+ batch_size=config.batch_size,
261
+ sharding=sharding,
262
+ shuffle=shuffle,
263
+ num_batches=num_batches,
264
+ num_workers=config.num_workers,
265
+ seed=config.seed,
266
+ skip_norm_stats=skip_norm_stats,
267
+ framework=framework,
268
+ )
269
+
270
+
271
+ def create_torch_data_loader(
272
+ data_config: _config.DataConfig,
273
+ model_config: _model.BaseModelConfig,
274
+ action_horizon: int,
275
+ batch_size: int,
276
+ *,
277
+ sharding: jax.sharding.Sharding | None = None,
278
+ skip_norm_stats: bool = False,
279
+ shuffle: bool = False,
280
+ num_batches: int | None = None,
281
+ num_workers: int = 0,
282
+ seed: int = 0,
283
+ framework: str = "jax",
284
+ ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
285
+ """Create a data loader for training.
286
+
287
+ Args:
288
+ data_config: The data configuration.
289
+ action_horizon: The action horizon.
290
+ batch_size: The batch size.
291
+ sharding: The sharding to use for the data loader. If None, the data loader will
292
+ use a single device sharding.
293
+ skip_norm_stats: Whether to skip data normalization.
294
+ shuffle: Whether to shuffle the data.
295
+ num_batches: Determines the number of batches to return. If the number exceeds the
296
+ number of batches in the dataset, the data loader will loop over the dataset.
297
+ If not provided, will iterate over the dataset indefinitely.
298
+ num_workers: The number of worker processes to use. If zero, the data loader will
299
+ execute in the main process.
300
+ seed: The seed to use for shuffling the data.
301
+ """
302
+ dataset = create_torch_dataset(data_config, action_horizon, model_config)
303
+ dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)
304
+
305
+ # Use TorchDataLoader for both frameworks
306
+ # For PyTorch DDP, create DistributedSampler and divide batch size by world size
307
+ # For JAX, divide by process count
308
+ sampler = None
309
+ if framework == "pytorch":
310
+ if torch.distributed.is_initialized():
311
+ sampler = torch.utils.data.distributed.DistributedSampler(
312
+ dataset,
313
+ num_replicas=torch.distributed.get_world_size(),
314
+ rank=torch.distributed.get_rank(),
315
+ shuffle=shuffle,
316
+ drop_last=True,
317
+ )
318
+ local_batch_size = batch_size // torch.distributed.get_world_size()
319
+ else:
320
+ local_batch_size = batch_size
321
+ else:
322
+ local_batch_size = batch_size // jax.process_count()
323
+
324
+ logging.info(f"local_batch_size: {local_batch_size}")
325
+ data_loader = TorchDataLoader(
326
+ dataset,
327
+ local_batch_size=local_batch_size,
328
+ sharding=None if framework == "pytorch" else sharding,
329
+ shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler
330
+ sampler=sampler,
331
+ num_batches=num_batches,
332
+ num_workers=num_workers,
333
+ seed=seed,
334
+ framework=framework,
335
+ )
336
+
337
+ return DataLoaderImpl(data_config, data_loader)
338
+
339
+
340
+ def create_rlds_data_loader(
341
+ data_config: _config.DataConfig,
342
+ action_horizon: int,
343
+ batch_size: int,
344
+ *,
345
+ sharding: jax.sharding.Sharding | None = None,
346
+ skip_norm_stats: bool = False,
347
+ shuffle: bool = False,
348
+ num_batches: int | None = None,
349
+ framework: str = "jax",
350
+ ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
351
+ """Create an RLDS data loader for training.
352
+
353
+ Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md
354
+
355
+ Args:
356
+ data_config: The data configuration.
357
+ action_horizon: The action horizon.
358
+ batch_size: The batch size.
359
+ sharding: The sharding to use for the data loader. If None, the data loader will
360
+ use a single device sharding.
361
+ skip_norm_stats: Whether to skip data normalization.
362
+ shuffle: Whether to shuffle the data.
363
+ num_batches: Determines the number of batches to return. If the number exceeds the
364
+ number of batches in the dataset, the data loader will loop over the dataset.
365
+ If not provided, will iterate over the dataset indefinitely.
366
+ """
367
+ if framework == "pytorch":
368
+ raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
369
+ dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
370
+ dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
371
+
372
+ data_loader = RLDSDataLoader(
373
+ dataset,
374
+ sharding=sharding,
375
+ num_batches=num_batches,
376
+ )
377
+
378
+ return DataLoaderImpl(data_config, data_loader)
379
+
380
+
381
+ class TorchDataLoader:
382
+ """Torch data loader implementation."""
383
+
384
+ def __init__(
385
+ self,
386
+ dataset,
387
+ local_batch_size: int,
388
+ *,
389
+ sharding: jax.sharding.Sharding | None = None,
390
+ shuffle: bool = False,
391
+ sampler: torch.utils.data.Sampler | None = None,
392
+ num_batches: int | None = None,
393
+ num_workers: int = 0,
394
+ seed: int = 0,
395
+ framework: str = "jax",
396
+ ):
397
+ """Create a PyTorch data loader.
398
+
399
+ Args:
400
+ dataset: The dataset to load.
401
+ local_batch_size: The local batch size for each process.
402
+ sharding: The sharding to use for the data loader.
403
+ shuffle: Whether to shuffle the data.
404
+ num_batches: If provided, determines the number of returned batches. If the
405
+ number is larger than the number of batches in the dataset, the data loader
406
+ will loop over the dataset. If not provided, will iterate over the dataset
407
+ indefinitely.
408
+ num_workers: The number of worker processes to use. If zero, the data loader will
409
+ execute in the main process.
410
+ seed: The seed to use for shuffling the data.
411
+ """
412
+ if jax.process_count() > 1:
413
+ raise NotImplementedError("Data loading with multiple processes is not supported.")
414
+
415
+ if len(dataset) < local_batch_size:
416
+ raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
417
+
418
+ # Store sharding - None for PyTorch, JAX sharding for JAX
419
+ self._sharding = sharding
420
+ if sharding is None and framework == "jax":
421
+ # Use data parallel sharding by default for JAX only.
422
+ self._sharding = jax.sharding.NamedSharding(
423
+ jax.sharding.Mesh(jax.devices(), ("B",)),
424
+ jax.sharding.PartitionSpec("B"),
425
+ )
426
+ self._num_batches = num_batches
427
+
428
+ mp_context = None
429
+ if num_workers > 0:
430
+ mp_context = multiprocessing.get_context("spawn")
431
+
432
+ generator = torch.Generator()
433
+ generator.manual_seed(seed)
434
+ self._data_loader = torch.utils.data.DataLoader(
435
+ typing.cast(torch.utils.data.Dataset, dataset),
436
+ batch_size=local_batch_size,
437
+ shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler
438
+ sampler=sampler,
439
+ num_workers=num_workers,
440
+ multiprocessing_context=mp_context,
441
+ persistent_workers=num_workers > 0,
442
+ collate_fn=_collate_fn,
443
+ worker_init_fn=_worker_init_fn,
444
+ drop_last=True,
445
+ generator=generator,
446
+ )
447
+
448
+ @property
449
+ def torch_loader(self) -> torch.utils.data.DataLoader:
450
+ return self._data_loader
451
+
452
+ def __iter__(self):
453
+ num_items = 0
454
+ while True:
455
+ data_iter = iter(self._data_loader)
456
+ while True:
457
+ if self._num_batches is not None and num_items >= self._num_batches:
458
+ return
459
+ try:
460
+ batch = next(data_iter)
461
+ except StopIteration:
462
+ break # We've exhausted the dataset. Create a new iterator and start over.
463
+ num_items += 1
464
+ # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
465
+ if self._sharding is not None:
466
+ yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
467
+ else:
468
+ yield jax.tree.map(torch.as_tensor, batch)
469
+
470
+
471
+ def _collate_fn(items):
472
+ """Collate the batch elements into batched numpy arrays."""
473
+ # Make sure to convert to numpy arrays before stacking since some of the incoming elements
474
+ # may be JAX arrays.
475
+ return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)
476
+
477
+
478
+ def _worker_init_fn(worker_id: int) -> None:
479
+ """Tell JAX inside the worker process not to preallocate the GPU memory."""
480
+ # NOTE: This is called after jax is imported inside the worker process. This
481
+ # means that this approach will not work for selecting the backend.
482
+ os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
483
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
484
+
485
+
486
+ class RLDSDataLoader:
487
+ """Shallow wrapper around the DROID data loader to make it compatible with openpi.
488
+
489
+ All batching already happens in the DROID dataset, so we don't need to do anything here.
490
+ """
491
+
492
+ def __init__(
493
+ self,
494
+ dataset: DroidRldsDataset,
495
+ *,
496
+ sharding: jax.sharding.Sharding | None = None,
497
+ num_batches: int | None = None,
498
+ ):
499
+ self._dataset = dataset
500
+ self._num_batches = num_batches
501
+
502
+ if jax.process_count() > 1:
503
+ raise NotImplementedError("Data loading with multiple processes is not supported.")
504
+
505
+ if sharding is None:
506
+ # Use data parallel sharding by default.
507
+ sharding = jax.sharding.NamedSharding(
508
+ jax.sharding.Mesh(jax.devices(), ("B",)),
509
+ jax.sharding.PartitionSpec("B"),
510
+ )
511
+
512
+ self._sharding = sharding
513
+ self._num_batches = num_batches
514
+
515
+ def __iter__(self):
516
+ num_items = 0
517
+ while True:
518
+ data_iter = iter(self._dataset)
519
+ while True:
520
+ if self._num_batches is not None and num_items >= self._num_batches:
521
+ return
522
+ try:
523
+ batch = next(data_iter)
524
+ except StopIteration:
525
+ break # We've exhausted the dataset. Create a new iterator and start over.
526
+ num_items += 1
527
+ yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
528
+
529
+
530
+ class DataLoaderImpl(DataLoader):
531
+ def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
532
+ self._data_config = data_config
533
+ self._data_loader = data_loader
534
+
535
+ def data_config(self) -> _config.DataConfig:
536
+ return self._data_config
537
+
538
+ def __iter__(self):
539
+ for batch in self._data_loader:
540
+ yield _model.Observation.from_dict(batch), batch["actions"]
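A minimal usage sketch for the loader above, mirroring the tests that follow (the "debug" config runs entirely on fake data, so no dataset download is required):

    from openpi.training import config as _config
    from openpi.training import data_loader as _data_loader

    config = _config.get_config("debug")
    loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2)
    for observation, actions in loader:
        # Each batch is an (Observation, actions) pair, where actions has shape
        # [batch_size, action_horizon, action_dim].
        print(actions.shape)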
capvector-pi05/src/openpi/training/data_loader_test.py ADDED
@@ -0,0 +1,84 @@
1
+ import dataclasses
2
+
3
+ import jax
4
+
5
+ from openpi.models import pi0_config
6
+ from openpi.training import config as _config
7
+ from openpi.training import data_loader as _data_loader
8
+
9
+
10
+ def test_torch_data_loader():
11
+ config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
12
+ dataset = _data_loader.FakeDataset(config, 16)
13
+
14
+ loader = _data_loader.TorchDataLoader(
15
+ dataset,
16
+ local_batch_size=4,
17
+ num_batches=2,
18
+ )
19
+ batches = list(loader)
20
+
21
+ assert len(batches) == 2
22
+ for batch in batches:
23
+ assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch))
24
+
25
+
26
+ def test_torch_data_loader_infinite():
27
+ config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
28
+ dataset = _data_loader.FakeDataset(config, 4)
29
+
30
+ loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4)
31
+ data_iter = iter(loader)
32
+
33
+ for _ in range(10):
34
+ _ = next(data_iter)
35
+
36
+
37
+ def test_torch_data_loader_parallel():
38
+ config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
39
+ dataset = _data_loader.FakeDataset(config, 10)
40
+
41
+ loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2)
42
+ batches = list(loader)
43
+
44
+ assert len(batches) == 2
45
+
46
+ for batch in batches:
47
+ assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch))
48
+
49
+
50
+ def test_with_fake_dataset():
51
+ config = _config.get_config("debug")
52
+
53
+ loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2)
54
+ batches = list(loader)
55
+
56
+ assert len(batches) == 2
57
+
58
+ for batch in batches:
59
+ assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch))
60
+
61
+ for _, actions in batches:
62
+ assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)
63
+
64
+
65
+ def test_with_real_dataset():
66
+ config = _config.get_config("pi0_aloha_sim")
67
+ config = dataclasses.replace(config, batch_size=4)
68
+
69
+ loader = _data_loader.create_data_loader(
70
+ config,
71
+ # Skip since we may not have the data available.
72
+ skip_norm_stats=True,
73
+ num_batches=2,
74
+ shuffle=True,
75
+ )
76
+ # Make sure that we can get the data config.
77
+ assert loader.data_config().repo_id == config.data.repo_id
78
+
79
+ batches = list(loader)
80
+
81
+ assert len(batches) == 2
82
+
83
+ for _, actions in batches:
84
+ assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)
capvector-pi05/src/openpi/training/droid_rlds_dataset.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ RLDS-based data loader for DROID.
3
+ While openpi typically uses LeRobot's data loader, that loader does not currently scale well enough for larger datasets like DROID.
4
+ Thus, we provide a data loader example here that uses the RLDS data format.
5
+ The data loader also applies a few DROID-specific data filters / transformations.
6
+ """
7
+
8
+ from enum import Enum
9
+ from enum import auto
10
+ import json
11
+ import logging
12
+ from pathlib import Path
13
+
14
+ import tqdm
15
+
16
+ import openpi.shared.download as download
17
+
18
+
19
+ class DroidActionSpace(Enum):
20
+ """Action space for DROID dataset."""
21
+
22
+ JOINT_POSITION = auto()
23
+ JOINT_VELOCITY = auto()
24
+
25
+
26
+ class DroidRldsDataset:
27
+ def __init__(
28
+ self,
29
+ data_dir: str,
30
+ batch_size: int,
31
+ *, # Force keyword-only arguments
32
+ shuffle: bool = True,
33
+ action_chunk_size: int = 16,
34
+ # We default to joint position actions, since they allow policy evaluation in simulation.
35
+ action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
36
+ max_loaded_steps_per_episode: int = 100,
37
+ # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
38
+ shuffle_buffer_size: int = 250_000,
39
+ num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
40
+ num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
41
+ filter_dict_path=None, # Path to json file with indices to sample during training
42
+ ):
43
+ # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
44
+ import dlimp as dl
45
+ import tensorflow as tf
46
+ import tensorflow_datasets as tfds
47
+
48
+ # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
49
+ tf.config.set_visible_devices([], "GPU")
50
+
51
+ builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1")
52
+ dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)
53
+
54
+ # Filter out any unsuccessful trajectories -- we use the file name to check this
55
+ dataset = dataset.filter(
56
+ lambda traj: tf.strings.regex_full_match(
57
+ traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
58
+ )
59
+ )
60
+
61
+ # Repeat dataset so we never run out of data.
62
+ dataset = dataset.repeat()
63
+
64
+ # Load the filter dictionary if provided.
65
+ # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
66
+ # (e.g.,
67
+ # {
68
+ # "<episode key>": [[0, 100], [200, 300]]
69
+ # }
70
+ # means keep frames 0-99 and 200-299).
71
+ if filter_dict_path is not None:
72
+ cached_filter_dict_path = download.maybe_download(filter_dict_path)
73
+ with Path(cached_filter_dict_path).open("r") as f:
74
+ filter_dict = json.load(f)
75
+
76
+ logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")
77
+
78
+ keys_tensor = []
79
+ values_tensor = []
80
+
81
+ for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
82
+ for start, end in ranges:
83
+ for t in range(start, end):
84
+ frame_key = f"{episode_key}--{t}"
85
+ keys_tensor.append(frame_key)
86
+ values_tensor.append(True)
87
+ self.filter_table = tf.lookup.StaticHashTable(
88
+ tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
89
+ )
90
+ logging.info("Filter hash table initialized")
91
+ else:
92
+ self.filter_table = tf.lookup.StaticHashTable(
93
+ tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
94
+ )
95
+
96
+ def restructure(traj):
97
+ """Reformat observation and action keys, sample language instruction."""
98
+ # Important: we use joint *position* action space -- easier to simulate!
99
+ actions = tf.concat(
100
+ (
101
+ (
102
+ traj["action_dict"]["joint_position"]
103
+ if action_space == DroidActionSpace.JOINT_POSITION
104
+ else traj["action_dict"]["joint_velocity"]
105
+ ),
106
+ traj["action_dict"]["gripper_position"],
107
+ ),
108
+ axis=-1,
109
+ )
110
+ # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
111
+ # Note: "left" refers to the left camera in the stereo pair; we only train on the left camera.
112
+ exterior_img = tf.cond(
113
+ tf.random.uniform(shape=[]) > 0.5,
114
+ lambda: traj["observation"]["exterior_image_1_left"],
115
+ lambda: traj["observation"]["exterior_image_2_left"],
116
+ )
117
+ wrist_img = traj["observation"]["wrist_image_left"]
118
+ # Randomly sample one of the three language instructions
119
+ instruction = tf.random.shuffle(
120
+ [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
121
+ )[0]
122
+
123
+ traj_len = tf.shape(traj["action"])[0]
124
+ indices = tf.as_string(tf.range(traj_len))
125
+
126
+ # Data filtering:
127
+ # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
128
+ # and each step's time step index. This will index into the filter hash table, and if it returns true,
129
+ # then the frame passes the filter.
130
+ step_id = (
131
+ traj["traj_metadata"]["episode_metadata"]["recording_folderpath"]
132
+ + "--"
133
+ + traj["traj_metadata"]["episode_metadata"]["file_path"]
134
+ + "--"
135
+ + indices
136
+ )
137
+ passes_filter = self.filter_table.lookup(step_id)
138
+
139
+ return {
140
+ "actions": actions,
141
+ "observation": {
142
+ "image": exterior_img,
143
+ "wrist_image": wrist_img,
144
+ "joint_position": traj["observation"]["joint_position"],
145
+ "gripper_position": traj["observation"]["gripper_position"],
146
+ },
147
+ "prompt": instruction,
148
+ "step_id": step_id,
149
+ "passes_filter": passes_filter,
150
+ }
151
+
152
+ dataset = dataset.traj_map(restructure, num_parallel_calls)
153
+
154
+ def chunk_actions(traj):
155
+ """Splits episode into action chunks."""
156
+ traj_len = tf.shape(traj["actions"])[0]
157
+
158
+ # For each step in the trajectory, construct indices for the next n actions
159
+ action_chunk_indices = tf.broadcast_to(
160
+ tf.range(action_chunk_size)[None],
161
+ [traj_len, action_chunk_size],
162
+ ) + tf.broadcast_to(
163
+ tf.range(traj_len)[:, None],
164
+ [traj_len, action_chunk_size],
165
+ )
166
+
167
+ # Cap to length of the sequence --> final chunks will repeat the last action
168
+ # This makes sense, since we are using absolute joint + gripper position actions
169
+ action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)
170
+
171
+ # Gather the actions for each chunk
172
+ traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
173
+ return traj
174
+
175
+ dataset = dataset.traj_map(chunk_actions, num_parallel_calls)
176
+
177
+ # Flatten: map from trajectory dataset to dataset of individual action chunks
178
+ dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)
179
+
180
+ # Filter data that doesn't pass the filter
181
+ def filter_from_dict(frame):
182
+ return frame["passes_filter"]
183
+
184
+ dataset = dataset.filter(filter_from_dict)
185
+
186
+ # Remove "passes_filter" key from output
187
+ def remove_passes_filter(frame):
188
+ frame.pop("passes_filter")
189
+ return frame
190
+
191
+ dataset = dataset.map(remove_passes_filter)
192
+
193
+ # Decode images: RLDS saves encoded images, only decode now for efficiency
194
+ def decode_images(traj):
195
+ traj["observation"]["image"] = tf.io.decode_image(
196
+ traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
197
+ )
198
+ traj["observation"]["wrist_image"] = tf.io.decode_image(
199
+ traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
200
+ )
201
+ return traj
202
+
203
+ dataset = dataset.frame_map(decode_images, num_parallel_calls)
204
+
205
+ # Shuffle, batch
206
+ dataset = dataset.shuffle(shuffle_buffer_size)
207
+ dataset = dataset.batch(batch_size)
208
+ # Note =>> Seems to reduce memory usage without affecting speed?
209
+ dataset = dataset.with_ram_budget(1)
210
+
211
+ self.dataset = dataset
212
+ self.batch_size = batch_size
213
+ self.shuffle = shuffle
214
+
215
+ def __iter__(self):
216
+ yield from self.dataset.as_numpy_iterator()
217
+
218
+ def __len__(self):
219
+ # This is the approximate number of samples in DROID after filtering.
220
+ # Easier to hardcode than to iterate through the dataset and compute it.
221
+ return 20_000_000
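A minimal usage sketch for the RLDS loader above, assuming a local TFDS data directory that actually contains the `droid` build ("/data/rlds" below is a placeholder, not a path from this repo); batches come back as nested NumPy dicts with the keys produced in `restructure`.

```python
# Sketch only: the data_dir is a placeholder.
dataset = DroidRldsDataset(
    data_dir="/data/rlds",
    batch_size=32,
    action_chunk_size=16,
    shuffle_buffer_size=100_000,  # lower than the default if memory is tight
)
batch = next(iter(dataset))
print(batch["actions"].shape)               # (32, 16, action_dim), e.g. 7 joints + gripper
print(batch["observation"]["image"].shape)  # (32, H, W, 3), decoded uint8
print(batch["prompt"][0])                   # one of the three sampled language instructions
```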
capvector-pi05/src/openpi/training/misc/roboarena_config.py ADDED
@@ -0,0 +1,116 @@
1
+ """RoboArena baseline policy configs."""
2
+
3
+ from typing import TypeAlias
4
+
5
+ import openpi.models.model as _model
6
+ import openpi.models.pi0_config as pi0_config
7
+ import openpi.models.pi0_fast as pi0_fast
8
+ import openpi.models.tokenizer as _tokenizer
9
+ import openpi.policies.droid_policy as droid_policy
10
+ import openpi.transforms as _transforms
11
+
12
+ ModelType: TypeAlias = _model.ModelType
13
+
14
+
15
+ def get_roboarena_configs():
16
+ # Import here to avoid circular imports.
17
+ from openpi.training.config import AssetsConfig
18
+ from openpi.training.config import DataConfig
19
+ from openpi.training.config import SimpleDataConfig
20
+ from openpi.training.config import TrainConfig
21
+
22
+ return [
23
+ #
24
+ # RoboArena DROID baseline inference configs.
25
+ #
26
+ TrainConfig(
27
+ # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
28
+ name="paligemma_binning_droid",
29
+ model=pi0_fast.Pi0FASTConfig(
30
+ action_dim=8,
31
+ action_horizon=15,
32
+ max_token_len=400,
33
+ fast_model_tokenizer=_tokenizer.BinningTokenizer,
34
+ ),
35
+ data=SimpleDataConfig(
36
+ assets=AssetsConfig(asset_id="droid"),
37
+ data_transforms=lambda model: _transforms.Group(
38
+ inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
39
+ outputs=[droid_policy.DroidOutputs()],
40
+ ),
41
+ base_config=DataConfig(
42
+ prompt_from_task=True,
43
+ ),
44
+ ),
45
+ ),
46
+ TrainConfig(
47
+ # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
48
+ name="paligemma_fast_droid",
49
+ model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
50
+ data=SimpleDataConfig(
51
+ assets=AssetsConfig(asset_id="droid"),
52
+ data_transforms=lambda model: _transforms.Group(
53
+ inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
54
+ outputs=[droid_policy.DroidOutputs()],
55
+ ),
56
+ base_config=DataConfig(
57
+ prompt_from_task=True,
58
+ ),
59
+ ),
60
+ ),
61
+ TrainConfig(
62
+ # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
63
+ name="paligemma_fast_specialist_droid",
64
+ model=pi0_fast.Pi0FASTConfig(
65
+ action_dim=8,
66
+ action_horizon=15,
67
+ fast_model_tokenizer=_tokenizer.FASTTokenizer,
68
+ fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
69
+ ),
70
+ data=SimpleDataConfig(
71
+ assets=AssetsConfig(asset_id="droid"),
72
+ data_transforms=lambda model: _transforms.Group(
73
+ inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
74
+ outputs=[droid_policy.DroidOutputs()],
75
+ ),
76
+ base_config=DataConfig(
77
+ prompt_from_task=True,
78
+ ),
79
+ ),
80
+ ),
81
+ TrainConfig(
82
+ # Trained from PaliGemma, using FSQ tokenizer.
83
+ name="paligemma_vq_droid",
84
+ model=pi0_fast.Pi0FASTConfig(
85
+ action_dim=8,
86
+ action_horizon=15,
87
+ fast_model_tokenizer=_tokenizer.FSQTokenizer,
88
+ fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
89
+ ),
90
+ data=SimpleDataConfig(
91
+ assets=AssetsConfig(asset_id="droid"),
92
+ data_transforms=lambda model: _transforms.Group(
93
+ inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
94
+ outputs=[droid_policy.DroidOutputs()],
95
+ ),
96
+ base_config=DataConfig(
97
+ prompt_from_task=True,
98
+ ),
99
+ ),
100
+ ),
101
+ TrainConfig(
102
+ # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
103
+ name="paligemma_diffusion_droid",
104
+ model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
105
+ data=SimpleDataConfig(
106
+ assets=AssetsConfig(asset_id="droid"),
107
+ data_transforms=lambda model: _transforms.Group(
108
+ inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
109
+ outputs=[droid_policy.DroidOutputs()],
110
+ ),
111
+ base_config=DataConfig(
112
+ prompt_from_task=True,
113
+ ),
114
+ ),
115
+ ),
116
+ ]
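These baseline configs are returned as a plain list; a small hypothetical helper (not part of the repo) to look one up by name might look like this.

```python
# Hypothetical lookup helper; assumes only that each TrainConfig has a unique .name.
def get_roboarena_config(name: str):
    configs = {c.name: c for c in get_roboarena_configs()}
    if name not in configs:
        raise ValueError(f"Unknown config {name!r}; available: {sorted(configs)}")
    return configs[name]

cfg = get_roboarena_config("paligemma_fast_droid")
print(cfg.model.action_horizon)  # 15
```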
capvector-pi05/src/openpi/training/optimizer.py ADDED
@@ -0,0 +1,109 @@
1
+ import dataclasses
2
+ from typing import Protocol, runtime_checkable
3
+
4
+ import jax.numpy as jnp
5
+ import optax
6
+
7
+ import openpi.shared.array_typing as at
8
+
9
+
10
+ @runtime_checkable
11
+ class LRScheduleConfig(Protocol):
12
+ def create(self) -> optax.Schedule: ...
13
+
14
+
15
+ @dataclasses.dataclass(frozen=True)
16
+ class CosineDecaySchedule(LRScheduleConfig):
17
+ """Cosine decay schedule with warmup."""
18
+
19
+ warmup_steps: int = 1_000
20
+ peak_lr: float = 2.5e-5
21
+ decay_steps: int = 30_000
22
+ decay_lr: float = 2.5e-6
23
+
24
+ def create(self) -> optax.Schedule:
25
+ return optax.warmup_cosine_decay_schedule(
26
+ init_value=self.peak_lr / (self.warmup_steps + 1),
27
+ peak_value=self.peak_lr,
28
+ warmup_steps=self.warmup_steps,
29
+ decay_steps=self.decay_steps,
30
+ end_value=self.decay_lr,
31
+ )
32
+
33
+
34
+ @dataclasses.dataclass(frozen=True)
35
+ class RsqrtDecaySchedule(LRScheduleConfig):
36
+ """Inverse square root decay schedule with warmup."""
37
+
38
+ warmup_steps: int = 1_000
39
+ peak_lr: float = 5e-5
40
+ timescale: float = 10_000
41
+
42
+ def create(self) -> optax.Schedule:
43
+ return optax.join_schedules(
44
+ [
45
+ optax.linear_schedule(
46
+ init_value=self.peak_lr / (self.warmup_steps + 1),
47
+ end_value=self.peak_lr,
48
+ transition_steps=self.warmup_steps,
49
+ ),
50
+ lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale),
51
+ ],
52
+ [self.warmup_steps],
53
+ )
54
+
55
+
56
+ @runtime_checkable
57
+ class OptimizerConfig(Protocol):
58
+ def create(
59
+ self,
60
+ lr: optax.ScalarOrSchedule,
61
+ weight_decay_mask: at.PyTree | None = None,
62
+ ) -> optax.GradientTransformation: ...
63
+
64
+
65
+ @dataclasses.dataclass(frozen=True)
66
+ class AdamW(OptimizerConfig):
67
+ """AdamW optimizer."""
68
+
69
+ b1: float = 0.9
70
+ b2: float = 0.95
71
+ eps: float = 1e-8
72
+ # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.
73
+ weight_decay: float = 1e-10
74
+ clip_gradient_norm: float = 1.0
75
+
76
+ def create(
77
+ self,
78
+ lr: optax.ScalarOrSchedule,
79
+ weight_decay_mask: at.PyTree | None = None,
80
+ ) -> optax.GradientTransformation:
81
+ tx = optax.adamw(
82
+ lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask
83
+ )
84
+
85
+ return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx)
86
+
87
+
88
+ @dataclasses.dataclass(frozen=True)
89
+ class SGD(OptimizerConfig):
90
+ """SGD optimizer."""
91
+
92
+ lr: float = 5e-5
93
+ momentum: float = 0.9
94
+ nesterov: bool = False
95
+
96
+ def create(
97
+ self,
98
+ lr: optax.ScalarOrSchedule,
99
+ weight_decay_mask: at.PyTree | None = None,
100
+ ) -> optax.GradientTransformation:
101
+ assert weight_decay_mask is None, "Weight decay is not supported for SGD"
102
+ return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)
103
+
104
+
105
+ def create_optimizer(
106
+ optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
107
+ ) -> optax.GradientTransformation:
108
+ lr = lr_schedule.create()
109
+ return optimizer.create(lr, weight_decay_mask=weight_decay_mask)
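A minimal sketch of wiring a schedule and an optimizer together with `create_optimizer` and running one update step on a toy parameter pytree; this assumes only the classes defined above plus standard JAX/optax APIs.

```python
import jax
import jax.numpy as jnp
import optax

tx = create_optimizer(AdamW(), CosineDecaySchedule())   # gradient clipping + AdamW + cosine LR

params = {"w": jnp.zeros((8, 8)), "b": jnp.zeros((8,))}
opt_state = tx.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)   # dummy gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```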
capvector-pi05/src/openpi/training/sharding.py ADDED
@@ -0,0 +1,102 @@
1
+ import contextlib
2
+ import logging
3
+
4
+ import jax
5
+ import numpy as np
6
+
7
+ BATCH_AXIS = "batch"
8
+ FSDP_AXIS = "fsdp"
9
+ # In FSDP, we shard the data across both the batch and FSDP axes.
10
+ DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)
11
+
12
+
13
+ class _MeshState:
14
+ active_mesh: jax.sharding.Mesh | None = None
15
+
16
+
17
+ def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
18
+ if jax.device_count() % num_fsdp_devices != 0:
19
+ raise ValueError(
20
+ f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}."
21
+ )
22
+ mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices)
23
+ return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS))
24
+
25
+
26
+ @contextlib.contextmanager
27
+ def set_mesh(mesh: jax.sharding.Mesh):
28
+ """Plumbing the mesh deep into the module tree is extremeley cumbersome; until the JAX team lands a better API, a
29
+ custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used
30
+ in `activation_sharding_constraint` below."""
31
+ if _MeshState.active_mesh is not None:
32
+ raise ValueError("Cannot nest set_mesh context managers.")
33
+ _MeshState.active_mesh = mesh
34
+ try:
35
+ yield
36
+ finally:
37
+ _MeshState.active_mesh = None
38
+
39
+
40
+ def activation_sharding_constraint(pytree):
41
+ if _MeshState.active_mesh is None:
42
+ return pytree
43
+ return jax.lax.with_sharding_constraint(
44
+ pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS))
45
+ )
46
+
47
+
48
+ def fsdp_sharding(
49
+ pytree,
50
+ mesh: jax.sharding.Mesh,
51
+ *,
52
+ min_size_mbytes: int = 4, # 4 MiB
53
+ log: bool = False,
54
+ ):
55
+ """Apply FSDP sharding to a pytree of arrays based on the mesh shape.
56
+
57
+ Args:
58
+ pytree: A pytree to which the sharding specified by the mesh is applied; note that only array types (i.e. objects with a .shape attribute)
59
+ will be considered for sharding.
60
+ mesh: The mesh used to apply sharding to the pytree.
61
+ min_size_mbytes: The minimum size of an array, in MiB, to be considered for sharding; any array smaller than this
62
+ will be replicated.
63
+ log: If true, will log the sharding decisions for arrays that are being considered for sharding.
64
+
65
+ Returns:
66
+ A pytree of jax.sharding.NamedSharding objects matching the structure of the input pytree.
67
+ """
68
+ min_size_bytes = min_size_mbytes * 2**20
69
+
70
+ def _shard_arr(kp, array: jax.ShapeDtypeStruct):
71
+ # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging
72
+ if mesh.shape[FSDP_AXIS] == 1:
73
+ return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
74
+ # replicate scalar and vector arrays
75
+ if not hasattr(array, "shape"):
76
+ return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
77
+ if len(array.shape) < 2:
78
+ return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
79
+ # replicate small arrays
80
+ if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:
81
+ return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
82
+
83
+ # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension
84
+ axes = np.argsort(array.shape)[::-1]
85
+ spec = [None] * len(axes)
86
+ for i in axes:
87
+ if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
88
+ if log:
89
+ logging.info(
90
+ f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}"
91
+ )
92
+ spec[i] = FSDP_AXIS
93
+ return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
94
+
95
+ # replicate if no valid sharding was found
96
+ if log:
97
+ logging.warning(
98
+ f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}"
99
+ )
100
+ return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
101
+
102
+ return jax.tree_util.tree_map_with_path(_shard_arr, pytree)
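A short sketch of the intended flow: build a mesh, compute parameter shardings from shapes alone, and use `set_mesh` so that `activation_sharding_constraint` can see the mesh from inside a jitted training step. This assumes a machine whose device count is divisible by the chosen FSDP size; `activation_sharding_constraint` is a no-op outside the `set_mesh` context.

```python
import jax
import jax.numpy as jnp

mesh = make_mesh(num_fsdp_devices=jax.device_count())   # pure-FSDP mesh as one valid choice

# Shape-only pytrees are enough; no real parameters are needed to plan the sharding.
param_shapes = {"w": jax.ShapeDtypeStruct((4096, 4096), jnp.bfloat16)}
shardings = fsdp_sharding(param_shapes, mesh, log=True)

@jax.jit
def step(x):
    # constrain activations to be sharded along the data axes
    return activation_sharding_constraint(x * 2.0)

with set_mesh(mesh):
    y = step(jnp.zeros((jax.device_count(), 128)))
```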
capvector-pi05/src/openpi/training/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from collections.abc import Callable
2
+ from typing import Any
3
+
4
+ from flax import nnx
5
+ from flax import struct
6
+ import jax
7
+ import optax
8
+
9
+ from openpi.models import model as _model
10
+ from openpi.shared import array_typing as at
11
+
12
+
13
+ @at.typecheck
14
+ @struct.dataclass
15
+ class TrainState:
16
+ step: at.Int[at.ArrayLike, ""]
17
+ params: nnx.State
18
+ model_def: nnx.GraphDef[_model.BaseModel]
19
+ opt_state: optax.OptState
20
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
21
+
22
+ ema_decay: float | None = struct.field(pytree_node=False)
23
+ ema_params: nnx.State | None = None
24
+
25
+
26
+ @at.typecheck
27
+ def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
28
+ """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert
29
+ the leaf values to more meaningful strings.
30
+ """
31
+ tree, _ = jax.tree_util.tree_flatten_with_path(tree)
32
+ return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree)
33
+
34
+
35
+ @at.typecheck
36
+ def array_tree_to_info(tree: at.PyTree) -> str:
37
+ """Converts a PyTree of arrays into a human-readable string for logging."""
38
+ return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}")
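A quick illustration of the logging helpers above; the output is one `path: shape@dtype` line per leaf.

```python
import jax.numpy as jnp

tree = {"bias": jnp.zeros((4,), dtype=jnp.int32), "encoder": {"w": jnp.zeros((3, 4))}}
print(array_tree_to_info(tree))
# ['bias']: (4,)@int32
# ['encoder']['w']: (3, 4)@float32
```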
capvector-pi05/src/openpi/training/weight_loaders.py ADDED
@@ -0,0 +1,104 @@
1
+ import dataclasses
2
+ import logging
3
+ import re
4
+ from typing import Protocol, runtime_checkable
5
+
6
+ import flax.traverse_util
7
+ import numpy as np
8
+
9
+ import openpi.models.model as _model
10
+ import openpi.shared.array_typing as at
11
+ import openpi.shared.download as download
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @runtime_checkable
17
+ class WeightLoader(Protocol):
18
+ def load(self, params: at.Params) -> at.Params:
19
+ """Loads the model weights.
20
+
21
+ Args:
22
+ params: Parameters of the model. This is a nested structure of array-like objects that
23
+ represent the model's parameters.
24
+
25
+ Returns:
26
+ Loaded parameters. The structure must be identical to `params`. If returning a subset of
27
+ the parameters the loader must merge the loaded parameters with `params`.
28
+ """
29
+
30
+
31
+ @dataclasses.dataclass(frozen=True)
32
+ class NoOpWeightLoader(WeightLoader):
33
+ def load(self, params: at.Params) -> at.Params:
34
+ return params
35
+
36
+
37
+ @dataclasses.dataclass(frozen=True)
38
+ class CheckpointWeightLoader(WeightLoader):
39
+ """Loads an entire set of weights from a checkpoint.
40
+
41
+ Compatible with:
42
+ trained checkpoints:
43
+ example: "./checkpoints/<config>/<exp>/<step>/params"
44
+ released checkpoints:
45
+ example: "gs://openpi-assets/checkpoints/<model>/params"
46
+ """
47
+
48
+ params_path: str
49
+
50
+ def load(self, params: at.Params) -> at.Params:
51
+ # We are loading np.ndarray and relying on the training code to properly convert and shard the params.
52
+ loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)
53
+ # Add all missing LoRA weights.
54
+ return _merge_params(loaded_params, params, missing_regex=".*lora.*")
55
+
56
+
57
+ @dataclasses.dataclass(frozen=True)
58
+ class PaliGemmaWeightLoader(WeightLoader):
59
+ """Loads weights from the official PaliGemma checkpoint.
60
+
61
+ This will overwrite existing weights with similar names while keeping all extra weights intact.
62
+ This allows us to support the action expert which is used by the Pi0 model.
63
+ """
64
+
65
+ def load(self, params: at.Params) -> at.Params:
66
+ path = download.maybe_download(
67
+ "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz", gs={"token": "anon"}
68
+ )
69
+ with path.open("rb") as f:
70
+ flat_params = dict(np.load(f, allow_pickle=False))
71
+ loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]}
72
+ # Add all missing weights.
73
+ return _merge_params(loaded_params, params, missing_regex=".*")
74
+
75
+
76
+ def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
77
+ """Merges the loaded parameters with the reference parameters.
78
+
79
+ Args:
80
+ loaded_params: The parameters to merge.
81
+ params: The reference parameters.
82
+ missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.
83
+
84
+ Returns:
85
+ A new dictionary with the merged parameters.
86
+ """
87
+ flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
88
+ flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")
89
+
90
+ # First, take all weights that are a subset of the reference weights.
91
+ result = {}
92
+ for k, v in flat_loaded.items():
93
+ if k in flat_ref:
94
+ result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
95
+
96
+ flat_loaded.clear()
97
+
98
+ # Then, merge any missing weights as defined by the missing regex.
99
+ pattern = re.compile(missing_regex)
100
+ for k in {k for k in flat_ref if pattern.fullmatch(k)}:
101
+ if k not in result:
102
+ result[k] = flat_ref[k]
103
+
104
+ return flax.traverse_util.unflatten_dict(result, sep="/")
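A toy sketch of the merge semantics above: loaded entries that exist in the reference tree are kept (cast to the reference dtype), and reference-only entries are pulled in only when their flattened key matches `missing_regex`.

```python
import numpy as np

ref = {
    "PaliGemma": {"w": np.zeros((2, 2), dtype=np.float32)},
    "lora_a": np.zeros((2,), dtype=np.float32),
}
loaded = {"PaliGemma": {"w": np.ones((2, 2), dtype=np.float64)}}

merged = _merge_params(loaded, ref, missing_regex=".*lora.*")
# merged["PaliGemma"]["w"] comes from `loaded`, cast to float32;
# merged["lora_a"] is filled from `ref` because its key matches the regex.
```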
capvector-pi05/src/vggt/__init__.py ADDED
File without changes
capvector-pi05/src/vggt/dependency/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .track_modules.track_refine import refine_track
2
+ from .track_modules.blocks import BasicEncoder, ShallowEncoder
3
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
capvector-pi05/src/vggt/dependency/distortion.py ADDED
@@ -0,0 +1,182 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import numpy as np
9
+ from typing import Union
10
+
11
+ ArrayLike = Union[np.ndarray, torch.Tensor]
12
+
13
+
14
+ def _is_numpy(x: ArrayLike) -> bool:
15
+ return isinstance(x, np.ndarray)
16
+
17
+
18
+ def _is_torch(x: ArrayLike) -> bool:
19
+ return isinstance(x, torch.Tensor)
20
+
21
+
22
+ def _ensure_torch(x: ArrayLike) -> torch.Tensor:
23
+ """Convert input to torch tensor if it's not already one."""
24
+ if _is_numpy(x):
25
+ return torch.from_numpy(x)
26
+ elif _is_torch(x):
27
+ return x
28
+ else:
29
+ return torch.tensor(x)
30
+
31
+
32
+ def single_undistortion(params, tracks_normalized):
33
+ """
34
+ Apply a single (non-iterative) undistortion step to the normalized tracks using the given distortion parameters.
35
+
36
+ Args:
37
+ params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
38
+ tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
39
+
40
+ Returns:
41
+ torch.Tensor: Undistorted normalized tracks tensor.
42
+ """
43
+ params = _ensure_torch(params)
44
+ tracks_normalized = _ensure_torch(tracks_normalized)
45
+
46
+ u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
47
+ u_undist, v_undist = apply_distortion(params, u, v)
48
+ return torch.stack([u_undist, v_undist], dim=-1)
49
+
50
+
51
+ def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6):
52
+ """
53
+ Iteratively undistort the normalized tracks using the given distortion parameters.
54
+
55
+ Args:
56
+ params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
57
+ tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
58
+ max_iterations (int): Maximum number of iterations for the undistortion process.
59
+ max_step_norm (float): Maximum step norm for convergence.
60
+ rel_step_size (float): Relative step size for numerical differentiation.
61
+
62
+ Returns:
63
+ torch.Tensor: Undistorted normalized tracks tensor.
64
+ """
65
+ params = _ensure_torch(params)
66
+ tracks_normalized = _ensure_torch(tracks_normalized)
67
+
68
+ B, N, _ = tracks_normalized.shape
69
+ u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
70
+ original_u, original_v = u.clone(), v.clone()
71
+
72
+ eps = torch.finfo(u.dtype).eps
73
+ for idx in range(max_iterations):
74
+ u_undist, v_undist = apply_distortion(params, u, v)
75
+ dx = original_u - u_undist
76
+ dy = original_v - v_undist
77
+
78
+ step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps)
79
+ step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps)
80
+
81
+ J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u)
82
+ J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v)
83
+ J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u)
84
+ J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v)
85
+
86
+ J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2)
87
+
88
+ delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1))
89
+
90
+ u += delta[..., 0]
91
+ v += delta[..., 1]
92
+
93
+ if torch.max((delta**2).sum(dim=-1)) < max_step_norm:
94
+ break
95
+
96
+ return torch.stack([u, v], dim=-1)
97
+
98
+
99
+ def apply_distortion(extra_params, u, v):
100
+ """
101
+ Applies radial or OpenCV distortion to the given 2D points.
102
+
103
+ Args:
104
+ extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
105
+ u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
106
+ v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.
107
+
108
+ Returns:
109
+ points2D (torch.Tensor): Distorted 2D points of shape BxNx2.
110
+ """
111
+ extra_params = _ensure_torch(extra_params)
112
+ u = _ensure_torch(u)
113
+ v = _ensure_torch(v)
114
+
115
+ num_params = extra_params.shape[1]
116
+
117
+ if num_params == 1:
118
+ # Simple radial distortion
119
+ k = extra_params[:, 0]
120
+ u2 = u * u
121
+ v2 = v * v
122
+ r2 = u2 + v2
123
+ radial = k[:, None] * r2
124
+ du = u * radial
125
+ dv = v * radial
126
+
127
+ elif num_params == 2:
128
+ # RadialCameraModel distortion
129
+ k1, k2 = extra_params[:, 0], extra_params[:, 1]
130
+ u2 = u * u
131
+ v2 = v * v
132
+ r2 = u2 + v2
133
+ radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
134
+ du = u * radial
135
+ dv = v * radial
136
+
137
+ elif num_params == 4:
138
+ # OpenCVCameraModel distortion
139
+ k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3])
140
+ u2 = u * u
141
+ v2 = v * v
142
+ uv = u * v
143
+ r2 = u2 + v2
144
+ radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
145
+ du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2)
146
+ dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2)
147
+ else:
148
+ raise ValueError("Unsupported number of distortion parameters")
149
+
150
+ u = u.clone() + du
151
+ v = v.clone() + dv
152
+
153
+ return u, v
154
+
155
+
156
+ if __name__ == "__main__":
157
+ import random
158
+ import pycolmap
159
+
160
+ max_diff = 0
161
+ for i in range(1000):
162
+ # Define distortion parameters (assuming 1 parameter for simplicity)
163
+ B = random.randint(1, 500)
164
+ track_num = random.randint(100, 1000)
165
+ params = torch.rand((B, 1), dtype=torch.float32) # shape (B, 1): one distortion parameter per batch element
166
+ tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # shape (B, track_num, 2)
167
+
168
+ # Undistort the tracks
169
+ undistorted_tracks = iterative_undistortion(params, tracks_normalized)
170
+
171
+ for b in range(B):
172
+ pycolmap_intri = np.array([1, 0, 0, params[b].item()])
173
+ pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0)
174
+
175
+ undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy())
176
+ diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median()
177
+ max_diff = max(max_diff, diff)
178
+ print(f"diff: {diff}, max_diff: {max_diff}")
179
+
180
+ import pdb
181
+
182
+ pdb.set_trace()
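For reference, the distortion model that `apply_distortion` implements on normalized coordinates, restated as equations (the tangential terms appear only in the 4-parameter OpenCV case, and k_2 = 0 for the 1-parameter model); `iterative_undistortion` inverts this mapping numerically with finite-difference Jacobians, stopping once the squared update norm drops below `max_step_norm`.

```latex
% r^2 = u^2 + v^2
\begin{aligned}
\delta_u &= u\,(k_1 r^2 + k_2 r^4) + 2 p_1 u v + p_2\,(r^2 + 2u^2),\\
\delta_v &= v\,(k_1 r^2 + k_2 r^4) + 2 p_2 u v + p_1\,(r^2 + 2v^2),\\
(u', v') &= (u + \delta_u,\; v + \delta_v).
\end{aligned}
```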
capvector-pi05/src/vggt/dependency/np_to_pycolmap.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import pycolmap
9
+ from .projection import project_3D_points_np
10
+
11
+
12
+ def batch_np_matrix_to_pycolmap(
13
+ points3d,
14
+ extrinsics,
15
+ intrinsics,
16
+ tracks,
17
+ image_size,
18
+ masks=None,
19
+ max_reproj_error=None,
20
+ max_points3D_val=3000,
21
+ shared_camera=False,
22
+ camera_type="SIMPLE_PINHOLE",
23
+ extra_params=None,
24
+ min_inlier_per_frame=64,
25
+ points_rgb=None,
26
+ ):
27
+ """
28
+ Convert Batched NumPy Arrays to PyCOLMAP
29
+
30
+ Check https://github.com/colmap/pycolmap for more details about its format
31
+
32
+ NOTE that colmap expects images/cameras/points3D to be 1-indexed
33
+ so there is a +1 offset between colmap index and batch index
34
+
35
+
36
+ NOTE: different from VGGSfM, this function:
37
+ 1. Uses np instead of torch
38
+ 2. Frame indices and camera ids start from 1 rather than 0 (to fit the format of PyCOLMAP)
39
+ """
40
+ # points3d: Px3
41
+ # extrinsics: Nx3x4
42
+ # intrinsics: Nx3x3
43
+ # tracks: NxPx2
44
+ # masks: NxP
45
+ # image_size: 2, assume all the frames have been padded to the same size
46
+ # where N is the number of frames and P is the number of tracks
47
+
48
+ N, P, _ = tracks.shape
49
+ assert len(extrinsics) == N
50
+ assert len(intrinsics) == N
51
+ assert len(points3d) == P
52
+ assert image_size.shape[0] == 2
53
+
54
+ reproj_mask = None
55
+
56
+ if max_reproj_error is not None:
57
+ projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics)
58
+ projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1)
59
+ projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6
60
+ reproj_mask = projected_diff < max_reproj_error
61
+
62
+ if masks is not None and reproj_mask is not None:
63
+ masks = np.logical_and(masks, reproj_mask)
64
+ elif masks is not None:
65
+ masks = masks
66
+ else:
67
+ masks = reproj_mask
68
+
69
+ assert masks is not None
70
+
71
+ if masks.sum(1).min() < min_inlier_per_frame:
72
+ print(f"Not enough inliers per frame, skip BA.")
73
+ return None, None
74
+
75
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
76
+ reconstruction = pycolmap.Reconstruction()
77
+
78
+ inlier_num = masks.sum(0)
79
+ valid_mask = inlier_num >= 2 # a track is invalid if it has fewer than two inliers
80
+ valid_idx = np.nonzero(valid_mask)[0]
81
+
82
+ # Only add 3D points that have sufficient 2D points
83
+ for vidx in valid_idx:
84
+ # Use RGB colors if provided, otherwise use zeros
85
+ rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3)
86
+ reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb)
87
+
88
+ num_points3D = len(valid_idx)
89
+ camera = None
90
+ # frame idx
91
+ for fidx in range(N):
92
+ # set camera
93
+ if camera is None or (not shared_camera):
94
+ pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params)
95
+
96
+ camera = pycolmap.Camera(
97
+ model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1
98
+ )
99
+
100
+ # add camera
101
+ reconstruction.add_camera(camera)
102
+
103
+ # set image
104
+ cam_from_world = pycolmap.Rigid3d(
105
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
106
+ ) # Rot and Trans
107
+
108
+ image = pycolmap.Image(
109
+ id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world
110
+ )
111
+
112
+ points2D_list = []
113
+
114
+ point2D_idx = 0
115
+
116
+ # NOTE point3D_id start by 1
117
+ for point3D_id in range(1, num_points3D + 1):
118
+ original_track_idx = valid_idx[point3D_id - 1]
119
+
120
+ if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all():
121
+ if masks[fidx][original_track_idx]:
122
+ # It seems we don't need +0.5 for BA
123
+ point2D_xy = tracks[fidx][original_track_idx]
124
+ # Please note when adding the Point2D object
125
+ # It not only requires the 2D xy location, but also the id to 3D point
126
+ points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
127
+
128
+ # add element
129
+ track = reconstruction.points3D[point3D_id].track
130
+ track.add_element(fidx + 1, point2D_idx)
131
+ point2D_idx += 1
132
+
133
+ assert point2D_idx == len(points2D_list)
134
+
135
+ try:
136
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
137
+ image.registered = True
138
+ except:
139
+ print(f"frame {fidx + 1} is out of BA")
140
+ image.registered = False
141
+
142
+ # add image
143
+ reconstruction.add_image(image)
144
+
145
+ return reconstruction, valid_mask
146
+
147
+
148
+ def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"):
149
+ """
150
+ Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays.
151
+
152
+ Args:
153
+ reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
154
+ device (str): Ignored in NumPy version (kept for API compatibility).
155
+ camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
156
+
157
+ Returns:
158
+ tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
159
+ """
160
+
161
+ num_images = len(reconstruction.images)
162
+ max_points3D_id = max(reconstruction.point3D_ids())
163
+ points3D = np.zeros((max_points3D_id, 3))
164
+
165
+ for point3D_id in reconstruction.points3D:
166
+ points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
167
+
168
+ extrinsics = []
169
+ intrinsics = []
170
+
171
+ extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
172
+
173
+ for i in range(num_images):
174
+ # Extract and append extrinsics
175
+ pyimg = reconstruction.images[i + 1]
176
+ pycam = reconstruction.cameras[pyimg.camera_id]
177
+ matrix = pyimg.cam_from_world.matrix()
178
+ extrinsics.append(matrix)
179
+
180
+ # Extract and append intrinsics
181
+ calibration_matrix = pycam.calibration_matrix()
182
+ intrinsics.append(calibration_matrix)
183
+
184
+ if camera_type == "SIMPLE_RADIAL":
185
+ extra_params.append(pycam.params[-1])
186
+
187
+ # Convert lists to NumPy arrays instead of torch tensors
188
+ extrinsics = np.stack(extrinsics)
189
+ intrinsics = np.stack(intrinsics)
190
+
191
+ if camera_type == "SIMPLE_RADIAL":
192
+ extra_params = np.stack(extra_params)
193
+ extra_params = extra_params[:, None]
194
+
195
+ return points3D, extrinsics, intrinsics, extra_params
196
+
197
+
198
+ ########################################################
199
+
200
+
201
+ def batch_np_matrix_to_pycolmap_wo_track(
202
+ points3d,
203
+ points_xyf,
204
+ points_rgb,
205
+ extrinsics,
206
+ intrinsics,
207
+ image_size,
208
+ shared_camera=False,
209
+ camera_type="SIMPLE_PINHOLE",
210
+ ):
211
+ """
212
+ Convert Batched NumPy Arrays to PyCOLMAP
213
+
214
+ Different from batch_np_matrix_to_pycolmap, this function does not use tracks.
215
+
216
+ It saves points3d in the COLMAP reconstruction format only to serve as an initialization for Gaussians or other NVS methods.
217
+
218
+ Do NOT use this for BA.
219
+ """
220
+ # points3d: Px3
221
+ # points_xyf: Px3, with x, y coordinates and frame indices
222
+ # points_rgb: Px3, rgb colors
223
+ # extrinsics: Nx3x4
224
+ # intrinsics: Nx3x3
225
+ # image_size: 2, assume all the frames have been padded to the same size
226
+ # where N is the number of frames and P is the number of tracks
227
+
228
+ N = len(extrinsics)
229
+ P = len(points3d)
230
+
231
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
232
+ reconstruction = pycolmap.Reconstruction()
233
+
234
+ for vidx in range(P):
235
+ reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx])
236
+
237
+ camera = None
238
+ # frame idx
239
+ for fidx in range(N):
240
+ # set camera
241
+ if camera is None or (not shared_camera):
242
+ pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type)
243
+
244
+ camera = pycolmap.Camera(
245
+ model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1
246
+ )
247
+
248
+ # add camera
249
+ reconstruction.add_camera(camera)
250
+
251
+ # set image
252
+ cam_from_world = pycolmap.Rigid3d(
253
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
254
+ ) # Rot and Trans
255
+
256
+ image = pycolmap.Image(
257
+ id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world
258
+ )
259
+
260
+ points2D_list = []
261
+
262
+ point2D_idx = 0
263
+
264
+ points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx
265
+ points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0]
266
+
267
+ for point3D_batch_idx in points_belong_to_fidx:
268
+ point3D_id = point3D_batch_idx + 1
269
+ point2D_xyf = points_xyf[point3D_batch_idx]
270
+ point2D_xy = point2D_xyf[:2]
271
+ points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
272
+
273
+ # add element
274
+ track = reconstruction.points3D[point3D_id].track
275
+ track.add_element(fidx + 1, point2D_idx)
276
+ point2D_idx += 1
277
+
278
+ assert point2D_idx == len(points2D_list)
279
+
280
+ try:
281
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
282
+ image.registered = True
283
+ except:
284
+ print(f"frame {fidx + 1} does not have any points")
285
+ image.registered = False
286
+
287
+ # add image
288
+ reconstruction.add_image(image)
289
+
290
+ return reconstruction
291
+
292
+
293
+ def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None):
294
+ """
295
+ Helper function to get camera parameters based on camera type.
296
+
297
+ Args:
298
+ fidx: Frame index
299
+ intrinsics: Camera intrinsic parameters
300
+ camera_type: Type of camera model
301
+ extra_params: Additional parameters for certain camera types
302
+
303
+ Returns:
304
+ pycolmap_intri: NumPy array of camera parameters
305
+ """
306
+ if camera_type == "PINHOLE":
307
+ pycolmap_intri = np.array(
308
+ [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]
309
+ )
310
+ elif camera_type == "SIMPLE_PINHOLE":
311
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
312
+ pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]])
313
+ elif camera_type == "SIMPLE_RADIAL":
314
+ raise NotImplementedError("SIMPLE_RADIAL is not supported yet")
315
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
316
+ pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]])
317
+ else:
318
+ raise ValueError(f"Camera type {camera_type} is not supported yet")
319
+
320
+ return pycolmap_intri
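As a small check of the SIMPLE_PINHOLE branch above (COLMAP stores `[f, cx, cy]`, so fx and fy are averaged), with made-up intrinsics:

```python
import numpy as np

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 520.0, 240.0],
              [0.0, 0.0, 1.0]])
params = _build_pycolmap_intri(0, K[None], "SIMPLE_PINHOLE")
# array([510., 320., 240.])  -> focal = (fx + fy) / 2, then cx, cy
```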
capvector-pi05/src/vggt/dependency/projection.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import numpy as np
9
+ from .distortion import apply_distortion
10
+
11
+
12
+ def img_from_cam_np(
13
+ intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0
14
+ ) -> np.ndarray:
15
+ """
16
+ Apply intrinsics (and optional radial distortion) to camera-space points.
17
+
18
+ Args
19
+ ----
20
+ intrinsics : (B,3,3) camera matrix K.
21
+ points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ.
22
+ extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None.
23
+ default : value used for np.nan replacement.
24
+
25
+ Returns
26
+ -------
27
+ points2D : (B,N,2) pixel coordinates.
28
+ """
29
+ # 1. perspective divide ───────────────────────────────────────
30
+ z = points_cam[:, 2:3, :] # (B,1,N)
31
+ points_cam_norm = points_cam / z # (B,3,N)
32
+ uv = points_cam_norm[:, :2, :] # (B,2,N)
33
+
34
+ # 2. optional distortion ──────────────────────────────────────
35
+ if extra_params is not None:
36
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
37
+ uv = np.stack([uu, vv], axis=1) # (B,2,N)
38
+
39
+ # 3. homogeneous coords then K multiplication ─────────────────
40
+ ones = np.ones_like(uv[:, :1, :]) # (B,1,N)
41
+ points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N)
42
+
43
+ # batched mat-mul: K · [u v 1]ᵀ
44
+ points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N)
45
+ points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N)
46
+
47
+ return points2D.transpose(0, 2, 1) # (B,N,2)
48
+
49
+
50
+ def project_3D_points_np(
51
+ points3D: np.ndarray,
52
+ extrinsics: np.ndarray,
53
+ intrinsics: np.ndarray | None = None,
54
+ extra_params: np.ndarray | None = None,
55
+ *,
56
+ default: float = 0.0,
57
+ only_points_cam: bool = False,
58
+ ):
59
+ """
60
+ NumPy clone of ``project_3D_points``.
61
+
62
+ Parameters
63
+ ----------
64
+ points3D : (N,3) world-space points.
65
+ extrinsics : (B,3,4) [R|t] matrix for each of B cameras.
66
+ intrinsics : (B,3,3) K matrix (optional if you only need cam-space).
67
+ extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None.
68
+ default : value used to replace NaNs.
69
+ only_points_cam : if True, skip the projection and return points_cam with points2D as None.
70
+
71
+ Returns
72
+ -------
73
+ (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True,
74
+ and points_cam is (B,3,N) camera-space coordinates.
75
+ """
76
+ # ----- 0. prep sizes -----------------------------------------------------
77
+ N = points3D.shape[0] # #points
78
+ B = extrinsics.shape[0] # #cameras
79
+
80
+ # ----- 1. world → homogeneous -------------------------------------------
81
+ w_h = np.ones((N, 1), dtype=points3D.dtype)
82
+ points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4)
83
+
84
+ # broadcast to every camera (no actual copying with np.broadcast_to) ------
85
+ points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4)
86
+
87
+ # ----- 2. apply extrinsics (camera frame) ------------------------------
88
+ # X_cam = E · X_hom
89
+ # einsum: E_(b i j) · X_(b n j) → (b n i)
90
+ points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3)
91
+ points_cam = points_cam.transpose(0, 2, 1) # (B,3,N)
92
+
93
+ if only_points_cam:
94
+ return None, points_cam
95
+
96
+ # ----- 3. intrinsics + distortion ---------------------------------------
97
+ if intrinsics is None:
98
+ raise ValueError("`intrinsics` must be provided unless only_points_cam=True")
99
+
100
+ points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default)
101
+
102
+ return points2D, points_cam
103
+
104
+
105
+ def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False):
106
+ """
107
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
108
+ Args:
109
+ points3D (torch.Tensor): 3D points of shape Px3.
110
+ extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
111
+ intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
112
+ extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
113
+ default (float): Default value to replace NaNs.
114
+ only_points_cam (bool): If True, skip the projection and return points2D as None.
115
+
116
+ Returns:
117
+ tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True,
118
+ and points_cam is of shape Bx3xN.
119
+ """
120
+ with torch.cuda.amp.autocast(dtype=torch.double):
121
+ N = points3D.shape[0] # Number of points
122
+ B = extrinsics.shape[0] # Batch size, i.e., number of cameras
123
+ points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4
124
+ # Reshape for batch processing
125
+ points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4
126
+
127
+ # Step 1: Apply extrinsic parameters
128
+ # Transform 3D points to camera coordinate system for all cameras
129
+ points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2))
130
+
131
+ if only_points_cam:
132
+ return None, points_cam
133
+
134
+ # Step 2: Apply intrinsic parameters and (optional) distortion
135
+ points2D = img_from_cam(intrinsics, points_cam, extra_params, default)
136
+
137
+ return points2D, points_cam
138
+
139
+
140
+ def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
141
+ """
142
+ Applies intrinsic parameters and optional distortion to the given 3D points.
143
+
144
+ Args:
145
+ intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
146
+ points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
147
+ extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
148
+ default (float, optional): Default value to replace NaNs in the output.
149
+
150
+ Returns:
151
+ points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
152
+ """
153
+
154
+ # Normalize by the third coordinate (homogeneous division)
155
+ points_cam = points_cam / points_cam[:, 2:3, :]
156
+ # Extract uv
157
+ uv = points_cam[:, :2, :]
158
+
159
+ # Apply distortion if extra_params are provided
160
+ if extra_params is not None:
161
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
162
+ uv = torch.stack([uu, vv], dim=1)
163
+
164
+ # Prepare points_cam for batch matrix multiplication
165
+ points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN
166
+ # Apply intrinsic parameters using batch matrix multiplication
167
+ points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN
168
+
169
+ # Extract x and y coordinates
170
+ points2D = points2D_homo[:, :2, :] # Bx2xN
171
+
172
+ # Replace NaNs with default value
173
+ points2D = torch.nan_to_num(points2D, nan=default)
174
+
175
+ return points2D.transpose(1, 2) # BxNx2
176
+
177
+
178
+ if __name__ == "__main__":
179
+ # Set up example input
180
+ B, N = 24, 10240
181
+
182
+ for _ in range(100):
183
+ points3D = np.random.rand(N, 3).astype(np.float64)
184
+ extrinsics = np.random.rand(B, 3, 4).astype(np.float64)
185
+ intrinsics = np.random.rand(B, 3, 3).astype(np.float64)
186
+
187
+ # Convert to torch tensors
188
+ points3D_torch = torch.tensor(points3D)
189
+ extrinsics_torch = torch.tensor(extrinsics)
190
+ intrinsics_torch = torch.tensor(intrinsics)
191
+
192
+ # Run NumPy implementation
193
+ points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics)
194
+
195
+ # Run torch implementation
196
+ points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch)
197
+
198
+ # Convert torch output to numpy
199
+ points2D_torch_np = points2D_torch.detach().numpy()
200
+ points_cam_torch_np = points_cam_torch.detach().numpy()
201
+
202
+ # Compute difference
203
+ diff = np.abs(points2D_np - points2D_torch_np)
204
+ print("Difference between NumPy and PyTorch implementations:")
205
+ print(diff)
206
+
207
+ # Check max error
208
+ max_diff = np.max(diff)
209
+ print(f"Maximum difference: {max_diff}")
210
+
211
+ if np.allclose(points2D_np, points2D_torch_np, atol=1e-6):
212
+ print("Implementations match closely.")
213
+ else:
214
+ print("Significant differences detected.")
215
+
216
+ if points_cam_np is not None:
217
+ points_cam_diff = np.abs(points_cam_np - points_cam_torch_np)
218
+ print("Difference between NumPy and PyTorch camera-space coordinates:")
219
+ print(points_cam_diff)
220
+
221
+ # Check max error
222
+ max_cam_diff = np.max(points_cam_diff)
223
+ print(f"Maximum camera-space coordinate difference: {max_cam_diff}")
224
+
225
+ if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6):
226
+ print("Camera-space coordinates match closely.")
227
+ else:
228
+ print("Significant differences detected in camera-space coordinates.")
capvector-pi05/src/vggt/dependency/track_modules/__init__.py ADDED
File without changes
capvector-pi05/src/vggt/dependency/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ from .blocks import EfficientUpdateFormer, CorrBlock
12
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
13
+
14
+
15
+ class BaseTrackerPredictor(nn.Module):
16
+ def __init__(
17
+ self,
18
+ stride=4,
19
+ corr_levels=5,
20
+ corr_radius=4,
21
+ latent_dim=128,
22
+ hidden_size=384,
23
+ use_spaceatt=True,
24
+ depth=6,
25
+ fine=False,
26
+ ):
27
+ super(BaseTrackerPredictor, self).__init__()
28
+ """
29
+ The base template to create a track predictor
30
+
31
+ Modified from https://github.com/facebookresearch/co-tracker/
32
+ """
33
+
34
+ self.stride = stride
35
+ self.latent_dim = latent_dim
36
+ self.corr_levels = corr_levels
37
+ self.corr_radius = corr_radius
38
+ self.hidden_size = hidden_size
39
+ self.fine = fine
40
+
41
+ self.flows_emb_dim = latent_dim // 2
42
+ self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2
43
+
44
+ if self.fine:
45
+ # TODO this is the old dummy code, will remove this when we train next model
46
+ self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5
47
+ else:
48
+ self.transformer_dim += (4 - self.transformer_dim % 4) % 4
49
+
50
+ space_depth = depth if use_spaceatt else 0
51
+ time_depth = depth
52
+
53
+ self.updateformer = EfficientUpdateFormer(
54
+ space_depth=space_depth,
55
+ time_depth=time_depth,
56
+ input_dim=self.transformer_dim,
57
+ hidden_size=self.hidden_size,
58
+ output_dim=self.latent_dim + 2,
59
+ mlp_ratio=4.0,
60
+ add_space_attn=use_spaceatt,
61
+ )
62
+
63
+ self.norm = nn.GroupNorm(1, self.latent_dim)
64
+
65
+ # A linear layer to update track feats at each iteration
66
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
67
+
68
+ if not self.fine:
69
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
70
+
71
+ def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1):
72
+ """
73
+ query_points: B x N x 2, the number of batches, tracks, and xy
74
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
75
+ note HH and WW are the size of the feature maps, not of the original images
76
+ """
77
+ B, N, D = query_points.shape
78
+ B, S, C, HH, WW = fmaps.shape
79
+
80
+ assert D == 2
81
+
82
+ # Scale the input query_points because we may downsample the images
83
+ # by down_ratio or self.stride
84
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
85
+ # its query_points should be query_points/4
86
+ if down_ratio > 1:
87
+ query_points = query_points / float(down_ratio)
88
+ query_points = query_points / float(self.stride)
89
+
90
+ # Init with coords as the query points
91
+ # It means the search will start from the position of query points at the reference frames
92
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
93
+
94
+ # Sample/extract the features of the query points in the query frame
95
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
96
+
97
+ # init track feats by query feats
98
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
99
+ # back up the init coords
100
+ coords_backup = coords.clone()
101
+
102
+ # Construct the correlation block
103
+
104
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
105
+
106
+ coord_preds = []
107
+
108
+ # Iterative Refinement
109
+ for itr in range(iters):
110
+ # Detach the gradients from the last iteration
111
+ # (in my experience, not very important for performance)
112
+ coords = coords.detach()
113
+
114
+ # Compute the correlation (check the implementation of CorrBlock)
115
+
116
+ fcorr_fn.corr(track_feats)
117
+ fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim
118
+
119
+ corrdim = fcorrs.shape[3]
120
+
121
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim)
122
+
123
+ # Movement of current coords relative to query points
124
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
125
+
126
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
127
+
128
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
129
+ flows_emb = torch.cat([flows_emb, flows], dim=-1)
130
+
131
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
132
+
133
+ # Concatenate them as the input for the transformers
134
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
135
+
136
+ if transformer_input.shape[2] < self.transformer_dim:
137
+ # pad the features to match the dimension
138
+ pad_dim = self.transformer_dim - transformer_input.shape[2]
139
+ pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
140
+ transformer_input = torch.cat([transformer_input, pad], dim=2)
141
+
142
+ # 2D positional embed
143
+ # TODO: this can be much simplified
144
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
145
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
146
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
147
+
148
+ x = transformer_input + sampled_pos_emb
149
+
150
+ # B, N, S, C
151
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
152
+
153
+ # Compute the delta coordinates and delta track features
154
+ delta = self.updateformer(x)
155
+ # BN, S, C
156
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
157
+ delta_coords_ = delta[:, :, :2]
158
+ delta_feats_ = delta[:, :, 2:]
159
+
160
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
161
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
162
+
163
+ # Update the track features
164
+ track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_
165
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
166
+
167
+ # B x S x N x 2
168
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
169
+
170
+ # Force coord0 as query
171
+ # because we assume the query points should not be changed
172
+ coords[:, 0] = coords_backup[:, 0]
173
+
174
+ # The predicted tracks are in the original image scale
175
+ if down_ratio > 1:
176
+ coord_preds.append(coords * self.stride * down_ratio)
177
+ else:
178
+ coord_preds.append(coords * self.stride)
179
+
180
+ # B, S, N
181
+ if not self.fine:
182
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
183
+ vis_e = torch.sigmoid(vis_e)
184
+ else:
185
+ vis_e = None
186
+
187
+ if return_feat:
188
+ return coord_preds, vis_e, track_feats, query_track_feat
189
+ else:
190
+ return coord_preds, vis_e
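A rough usage sketch for BaseTrackerPredictor (shapes follow the forward() docstring; the concrete sizes and the 256-pixel image scale below are made-up placeholders, not values used by the repo):

import torch

predictor = BaseTrackerPredictor(stride=4, latent_dim=128, hidden_size=384)
B, S, C, HH, WW = 1, 8, 128, 64, 64          # batch, frames, channels, feature-map height/width
fmaps = torch.randn(B, S, C, HH, WW)         # C must equal latent_dim, since query feats seed the track feats
query_points = torch.rand(B, 32, 2) * 255    # N=32 query xy in original-image coordinates (stride 4 feature maps)
coord_preds, vis = predictor(query_points, fmaps=fmaps, iters=4, down_ratio=1)
# coord_preds is a list of B x S x N x 2 estimates in the original image scale; vis is B x S x N in [0, 1]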
capvector-pi05/src/vggt/dependency/track_modules/blocks.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+ from typing import Callable
16
+ import collections
17
+ from torch import Tensor
18
+ from itertools import repeat
19
+
20
+ from .utils import bilinear_sampler
21
+
22
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
23
+
24
+
25
+ class BasicEncoder(nn.Module):
26
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
27
+ super(BasicEncoder, self).__init__()
28
+
29
+ self.stride = stride
30
+ self.norm_fn = "instance"
31
+ self.in_planes = output_dim // 2
32
+
33
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
34
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
35
+
36
+ self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros")
37
+ self.relu1 = nn.ReLU(inplace=True)
38
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
39
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
40
+ self.layer3 = self._make_layer(output_dim, stride=2)
41
+ self.layer4 = self._make_layer(output_dim, stride=2)
42
+
43
+ self.conv2 = nn.Conv2d(
44
+ output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros"
45
+ )
46
+ self.relu2 = nn.ReLU(inplace=True)
47
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
48
+
49
+ for m in self.modules():
50
+ if isinstance(m, nn.Conv2d):
51
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
52
+ elif isinstance(m, (nn.InstanceNorm2d)):
53
+ if m.weight is not None:
54
+ nn.init.constant_(m.weight, 1)
55
+ if m.bias is not None:
56
+ nn.init.constant_(m.bias, 0)
57
+
58
+ def _make_layer(self, dim, stride=1):
59
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
60
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
61
+ layers = (layer1, layer2)
62
+
63
+ self.in_planes = dim
64
+ return nn.Sequential(*layers)
65
+
66
+ def forward(self, x):
67
+ _, _, H, W = x.shape
68
+
69
+ x = self.conv1(x)
70
+ x = self.norm1(x)
71
+ x = self.relu1(x)
72
+
73
+ a = self.layer1(x)
74
+ b = self.layer2(a)
75
+ c = self.layer3(b)
76
+ d = self.layer4(c)
77
+
78
+ a = _bilinear_intepolate(a, self.stride, H, W)
79
+ b = _bilinear_intepolate(b, self.stride, H, W)
80
+ c = _bilinear_intepolate(c, self.stride, H, W)
81
+ d = _bilinear_intepolate(d, self.stride, H, W)
82
+
83
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
84
+ x = self.norm2(x)
85
+ x = self.relu2(x)
86
+ x = self.conv3(x)
87
+ return x
88
+
89
+
90
+ class ShallowEncoder(nn.Module):
91
+ def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
92
+ super(ShallowEncoder, self).__init__()
93
+ self.stride = stride
94
+ self.norm_fn = norm_fn
95
+ self.in_planes = output_dim
96
+
97
+ if self.norm_fn == "group":
98
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
99
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
100
+ elif self.norm_fn == "batch":
101
+ self.norm1 = nn.BatchNorm2d(self.in_planes)
102
+ self.norm2 = nn.BatchNorm2d(output_dim * 2)
103
+ elif self.norm_fn == "instance":
104
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
105
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
106
+ elif self.norm_fn == "none":
107
+ self.norm1 = nn.Sequential()
108
+
109
+ self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros")
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+
112
+ self.layer1 = self._make_layer(output_dim, stride=2)
113
+
114
+ self.layer2 = self._make_layer(output_dim, stride=2)
115
+ self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
116
+
117
+ for m in self.modules():
118
+ if isinstance(m, nn.Conv2d):
119
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
120
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
121
+ if m.weight is not None:
122
+ nn.init.constant_(m.weight, 1)
123
+ if m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+
126
+ def _make_layer(self, dim, stride=1):
127
+ self.in_planes = dim
128
+
129
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
130
+ return layer1
131
+
132
+ def forward(self, x):
133
+ _, _, H, W = x.shape
134
+
135
+ x = self.conv1(x)
136
+ x = self.norm1(x)
137
+ x = self.relu1(x)
138
+
139
+ tmp = self.layer1(x)
140
+ x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
141
+ tmp = self.layer2(tmp)
142
+ x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
143
+ tmp = None
144
+ x = self.conv2(x) + x
145
+
146
+ x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True)
147
+
148
+ return x
149
+
150
+
151
+ def _bilinear_intepolate(x, stride, H, W):
152
+ return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True)
153
+
154
+
155
+ class EfficientUpdateFormer(nn.Module):
156
+ """
157
+ Transformer model that updates track estimates.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ space_depth=6,
163
+ time_depth=6,
164
+ input_dim=320,
165
+ hidden_size=384,
166
+ num_heads=8,
167
+ output_dim=130,
168
+ mlp_ratio=4.0,
169
+ add_space_attn=True,
170
+ num_virtual_tracks=64,
171
+ ):
172
+ super().__init__()
173
+
174
+ self.out_channels = 2
175
+ self.num_heads = num_heads
176
+ self.hidden_size = hidden_size
177
+ self.add_space_attn = add_space_attn
178
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
179
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
180
+ self.num_virtual_tracks = num_virtual_tracks
181
+
182
+ if self.add_space_attn:
183
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
184
+ else:
185
+ self.virual_tracks = None
186
+
187
+ self.time_blocks = nn.ModuleList(
188
+ [
189
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
190
+ for _ in range(time_depth)
191
+ ]
192
+ )
193
+
194
+ if add_space_attn:
195
+ self.space_virtual_blocks = nn.ModuleList(
196
+ [
197
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
198
+ for _ in range(space_depth)
199
+ ]
200
+ )
201
+ self.space_point2virtual_blocks = nn.ModuleList(
202
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
203
+ )
204
+ self.space_virtual2point_blocks = nn.ModuleList(
205
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
206
+ )
207
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
208
+ self.initialize_weights()
209
+
210
+ def initialize_weights(self):
211
+ def _basic_init(module):
212
+ if isinstance(module, nn.Linear):
213
+ torch.nn.init.xavier_uniform_(module.weight)
214
+ if module.bias is not None:
215
+ nn.init.constant_(module.bias, 0)
216
+
217
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
218
+ """ViT weight initialization, original timm impl (for reproducibility)"""
219
+ if isinstance(module, nn.Linear):
220
+ trunc_normal_(module.weight, std=0.02)
221
+ if module.bias is not None:
222
+ nn.init.zeros_(module.bias)
223
+
224
+ def forward(self, input_tensor, mask=None):
225
+ tokens = self.input_transform(input_tensor)
226
+
227
+ init_tokens = tokens
228
+
229
+ B, _, T, _ = tokens.shape
230
+
231
+ if self.add_space_attn:
232
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
233
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
234
+
235
+ _, N, _, _ = tokens.shape
236
+
237
+ j = 0
238
+ for i in range(len(self.time_blocks)):
239
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
240
+ time_tokens = self.time_blocks[i](time_tokens)
241
+
242
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
243
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
244
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
245
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
246
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
247
+
248
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
249
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
250
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
251
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
252
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
253
+ j += 1
254
+
255
+ if self.add_space_attn:
256
+ tokens = tokens[:, : N - self.num_virtual_tracks]
257
+
258
+ tokens = tokens + init_tokens
259
+
260
+ flow = self.flow_head(tokens)
261
+ return flow
262
+
263
+
264
+ class CorrBlock:
265
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
266
+ B, S, C, H, W = fmaps.shape
267
+ self.S, self.C, self.H, self.W = S, C, H, W
268
+ self.padding_mode = padding_mode
269
+ self.num_levels = num_levels
270
+ self.radius = radius
271
+ self.fmaps_pyramid = []
272
+ self.multiple_track_feats = multiple_track_feats
273
+
274
+ self.fmaps_pyramid.append(fmaps)
275
+ for i in range(self.num_levels - 1):
276
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
277
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
278
+ _, _, H, W = fmaps_.shape
279
+ fmaps = fmaps_.reshape(B, S, C, H, W)
280
+ self.fmaps_pyramid.append(fmaps)
281
+
282
+ def sample(self, coords):
283
+ r = self.radius
284
+ B, S, N, D = coords.shape
285
+ assert D == 2
286
+
287
+ H, W = self.H, self.W
288
+ out_pyramid = []
289
+ for i in range(self.num_levels):
290
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
291
+ *_, H, W = corrs.shape
292
+
293
+ dx = torch.linspace(-r, r, 2 * r + 1)
294
+ dy = torch.linspace(-r, r, 2 * r + 1)
295
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
296
+
297
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
298
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
299
+ coords_lvl = centroid_lvl + delta_lvl
300
+
301
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode)
302
+ corrs = corrs.view(B, S, N, -1)
303
+
304
+ out_pyramid.append(corrs)
305
+
306
+ out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2
307
+ return out
308
+
309
+ def corr(self, targets):
310
+ B, S, N, C = targets.shape
311
+ if self.multiple_track_feats:
312
+ targets_split = targets.split(C // self.num_levels, dim=-1)
313
+ B, S, N, C = targets_split[0].shape
314
+
315
+ assert C == self.C
316
+ assert S == self.S
317
+
318
+ fmap1 = targets
319
+
320
+ self.corrs_pyramid = []
321
+ for i, fmaps in enumerate(self.fmaps_pyramid):
322
+ *_, H, W = fmaps.shape
323
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
324
+ if self.multiple_track_feats:
325
+ fmap1 = targets_split[i]
326
+ corrs = torch.matmul(fmap1, fmap2s)
327
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
328
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
329
+ self.corrs_pyramid.append(corrs)
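A small sketch of the CorrBlock flow defined above: corr() builds one correlation volume per pyramid level from the current track features, and sample() reads a (2*radius+1)^2 window around the given coordinates at every level. The sizes below are placeholders:

import torch

B, S, N, C, H, W = 1, 4, 16, 128, 32, 32
fmaps = torch.randn(B, S, C, H, W)           # per-frame feature maps
corr_fn = CorrBlock(fmaps, num_levels=4, radius=3)
track_feats = torch.randn(B, S, N, C)        # one feature vector per track per frame
coords = torch.rand(B, S, N, 2) * (W - 1)    # current track positions on the level-0 feature map
corr_fn.corr(track_feats)                    # builds B x S x N x H_l x W_l volumes for each level
corrs = corr_fn.sample(coords)               # B x S x N x (4 * (2*3+1)**2) = B x S x N x 196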
capvector-pi05/src/vggt/dependency/track_modules/modules.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
49
+ )
50
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
51
+ self.relu = nn.ReLU(inplace=True)
52
+
53
+ num_groups = planes // 8
54
+
55
+ if norm_fn == "group":
56
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
57
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
58
+ if not stride == 1:
59
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
60
+
61
+ elif norm_fn == "batch":
62
+ self.norm1 = nn.BatchNorm2d(planes)
63
+ self.norm2 = nn.BatchNorm2d(planes)
64
+ if not stride == 1:
65
+ self.norm3 = nn.BatchNorm2d(planes)
66
+
67
+ elif norm_fn == "instance":
68
+ self.norm1 = nn.InstanceNorm2d(planes)
69
+ self.norm2 = nn.InstanceNorm2d(planes)
70
+ if not stride == 1:
71
+ self.norm3 = nn.InstanceNorm2d(planes)
72
+
73
+ elif norm_fn == "none":
74
+ self.norm1 = nn.Sequential()
75
+ self.norm2 = nn.Sequential()
76
+ if not stride == 1:
77
+ self.norm3 = nn.Sequential()
78
+ else:
79
+ raise NotImplementedError
80
+
81
+ if stride == 1:
82
+ self.downsample = None
83
+ else:
84
+ self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
85
+
86
+ def forward(self, x):
87
+ y = x
88
+ y = self.relu(self.norm1(self.conv1(y)))
89
+ y = self.relu(self.norm2(self.conv2(y)))
90
+
91
+ if self.downsample is not None:
92
+ x = self.downsample(x)
93
+
94
+ return self.relu(x + y)
95
+
96
+
97
+ class Mlp(nn.Module):
98
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
99
+
100
+ def __init__(
101
+ self,
102
+ in_features,
103
+ hidden_features=None,
104
+ out_features=None,
105
+ act_layer=nn.GELU,
106
+ norm_layer=None,
107
+ bias=True,
108
+ drop=0.0,
109
+ use_conv=False,
110
+ ):
111
+ super().__init__()
112
+ out_features = out_features or in_features
113
+ hidden_features = hidden_features or in_features
114
+ bias = to_2tuple(bias)
115
+ drop_probs = to_2tuple(drop)
116
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
117
+
118
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
119
+ self.act = act_layer()
120
+ self.drop1 = nn.Dropout(drop_probs[0])
121
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
122
+ self.drop2 = nn.Dropout(drop_probs[1])
123
+
124
+ def forward(self, x):
125
+ x = self.fc1(x)
126
+ x = self.act(x)
127
+ x = self.drop1(x)
128
+ x = self.fc2(x)
129
+ x = self.drop2(x)
130
+ return x
131
+
132
+
133
+ class AttnBlock(nn.Module):
134
+ def __init__(
135
+ self,
136
+ hidden_size,
137
+ num_heads,
138
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
139
+ mlp_ratio=4.0,
140
+ **block_kwargs,
141
+ ):
142
+ """
143
+ Self attention block
144
+ """
145
+ super().__init__()
146
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
147
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
148
+
149
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
150
+
151
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
152
+
153
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
154
+
155
+ def forward(self, x, mask=None):
156
+ # Prepare the mask for PyTorch's attention (it expects a different format)
157
+ # attn_mask = mask if mask is not None else None
158
+ # Normalize before attention
159
+ x = self.norm1(x)
160
+
161
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
162
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
163
+
164
+ attn_output, _ = self.attn(x, x, x)
165
+
166
+ # Add & Norm
167
+ x = x + attn_output
168
+ x = x + self.mlp(self.norm2(x))
169
+ return x
170
+
171
+
172
+ class CrossAttnBlock(nn.Module):
173
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
174
+ """
175
+ Cross attention block
176
+ """
177
+ super().__init__()
178
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
179
+ self.norm_context = nn.LayerNorm(hidden_size)
180
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
181
+
182
+ self.cross_attn = nn.MultiheadAttention(
183
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
184
+ )
185
+
186
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
187
+
188
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
189
+
190
+ def forward(self, x, context, mask=None):
191
+ # Normalize inputs
192
+ x = self.norm1(x)
193
+ context = self.norm_context(context)
194
+
195
+ # Apply cross attention
196
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
197
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
198
+
199
+ # Add & Norm
200
+ x = x + attn_output
201
+ x = x + self.mlp(self.norm2(x))
202
+ return x
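A quick shape check for the blocks above (sizes are arbitrary): AttnBlock is pre-norm self-attention plus an MLP over (batch, tokens, channels) inputs, and CrossAttnBlock lets those tokens attend to a separate context sequence:

import torch

blk = AttnBlock(hidden_size=384, num_heads=8, mlp_ratio=4.0)
x = torch.randn(2, 100, 384)                 # (B, N, C) point tokens
y = blk(x)                                   # same shape: residual attention + residual MLP

cross = CrossAttnBlock(hidden_size=384, context_dim=384, num_heads=8)
context = torch.randn(2, 64, 384)            # (B, M, C) context tokens, e.g. virtual tracks
z = cross(x, context)                        # queries in x attend to the context, shape unchanged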
capvector-pi05/src/vggt/dependency/track_modules/track_refine.py ADDED
@@ -0,0 +1,419 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from functools import partial
13
+ from torch import nn, einsum
14
+ from einops import rearrange, repeat
15
+ from einops.layers.torch import Rearrange, Reduce
16
+
17
+ from PIL import Image
18
+ import os
19
+ from typing import Union, Tuple
20
+
21
+
22
+ def refine_track(
23
+ images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960
24
+ ):
25
+ """
26
+ Refines the tracking of images using a fine track predictor and a fine feature network.
27
+ Check https://arxiv.org/abs/2312.04563 for more details.
28
+
29
+ Args:
30
+ images (torch.Tensor): The images to be tracked.
31
+ fine_fnet (nn.Module): The fine feature network.
32
+ fine_tracker (nn.Module): The fine track predictor.
33
+ coarse_pred (torch.Tensor): The coarse predictions of tracks.
34
+ compute_score (bool, optional): Whether to compute the score. Defaults to False.
35
+ pradius (int, optional): The radius of a patch. Defaults to 15.
36
+ sradius (int, optional): The search radius. Defaults to 2.
37
+
38
+ Returns:
39
+ torch.Tensor: The refined tracks.
40
+ torch.Tensor, optional: The score.
41
+ """
42
+
43
+ # coarse_pred shape: BxSxNx2,
44
+ # where B is the batch, S is the video/images length, and N is the number of tracks
45
+ # now we are going to extract patches with the center at coarse_pred
46
+ # Please note that the last dimension indicates x and y, and hence has a dim number of 2
47
+ B, S, N, _ = coarse_pred.shape
48
+ _, _, _, H, W = images.shape
49
+
50
+ # Given the radius of a patch, compute the patch size
51
+ psize = pradius * 2 + 1
52
+
53
+ # Note that we assume the first frame is the query frame
54
+ # so the 2D locations of the first frame are the query points
55
+ query_points = coarse_pred[:, 0]
56
+
57
+ # Given 2D positions, we can use grid_sample to extract patches
58
+ # but it takes too much memory.
59
+ # Instead, we use the floored track xy to sample patches.
60
+
61
+ # For example, if the query point xy is (128.16, 252.78),
62
+ # and the patch size is (31, 31),
63
+ # our goal is to extract the content of a rectangle
64
+ # with left top: (113.16, 237.78)
65
+ # and right bottom: (143.16, 267.78).
66
+ # However, we record the floored left top: (113, 237)
67
+ # and the offset (0.16, 0.78)
68
+ # Then what we need is just unfolding the images like in CNN,
69
+ # picking the content at [(113, 237), (143, 267)].
70
+ # Such operations are highly optimized at pytorch
71
+ # (well if you really want to use interpolation, check the function extract_glimpse() below)
72
+
73
+ with torch.no_grad():
74
+ content_to_extract = images.reshape(B * S, 3, H, W)
75
+ C_in = content_to_extract.shape[1]
76
+
77
+ # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
78
+ # for the detailed explanation of unfold()
79
+ # Here it runs sliding windows (psize x psize) to build patches
80
+ # The shape changes from
81
+ # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
82
+ # where Psize is the size of patch
83
+ content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
84
+
85
+ # Floor the coarse predictions to get integers and save the fractional/decimal
86
+ track_int = coarse_pred.floor().int()
87
+ track_frac = coarse_pred - track_int
88
+
89
+ # Note the points represent the center of patches
90
+ # now we get the location of the top left corner of patches
91
+ # because the output of pytorch unfold is indexed by the top left corner
92
+ topleft = track_int - pradius
93
+ topleft_BSN = topleft.clone()
94
+
95
+ # clamp the values so that we will not go out of indexes
96
+ # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
97
+ # You need to separately clamp x and y if H!=W
98
+ topleft = topleft.clamp(0, H - psize)
99
+
100
+ # Reshape from BxSxNx2 -> (B*S)xNx2
101
+ topleft = topleft.reshape(B * S, N, 2)
102
+
103
+ # Prepare batches for indexing, shape: (B*S)xN
104
+ batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
105
+
106
+ # extracted_patches: (B*S) x N x C_in x Psize x Psize
107
+ extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]]
108
+
109
+ if chunk < 0:
110
+ # Extract image patches based on top left corners
111
+ # Feed patches to fine fnet for features
112
+ patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
113
+ else:
114
+ patches = extracted_patches.reshape(B * S * N, C_in, psize, psize)
115
+
116
+ patch_feat_list = []
117
+ for p in torch.split(patches, chunk):
118
+ patch_feat_list += [fine_fnet(p)]
119
+ patch_feat = torch.cat(patch_feat_list, 0)
120
+
121
+ C_out = patch_feat.shape[1]
122
+
123
+ # Refine the coarse tracks by fine_tracker
124
+ # reshape back to B x S x N x C_out x Psize x Psize
125
+ patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
126
+ patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
127
+
128
+ # Prepare for the query points for fine tracker
129
+ # They are relative to the patch left top corner,
130
+ # instead of the image top left corner now
131
+ # patch_query_points: N x 1 x 2
132
+ # only 1 here because for each patch we only have 1 query point
133
+ patch_query_points = track_frac[:, 0] + pradius
134
+ patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
135
+
136
+ # Feed the PATCH query points and tracks into fine tracker
137
+ fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
138
+ query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True
139
+ )
140
+
141
+ # relative to the patch top left
142
+ fine_pred_track = fine_pred_track_lists[-1].clone()
143
+
144
+ # From (relative to the patch top left) to (relative to the image top left)
145
+ for idx in range(len(fine_pred_track_lists)):
146
+ fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N)
147
+ fine_level = fine_level.squeeze(-2)
148
+ fine_level = fine_level + topleft_BSN
149
+ fine_pred_track_lists[idx] = fine_level
150
+
151
+ # relative to the image top left
152
+ refined_tracks = fine_pred_track_lists[-1].clone()
153
+ refined_tracks[:, 0] = query_points
154
+
155
+ score = None
156
+
157
+ if compute_score:
158
+ score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out)
159
+
160
+ return refined_tracks, score
161
+
162
+
163
+ def refine_track_v0(
164
+ images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6
165
+ ):
166
+ """
167
+ COPIED FROM VGGSfM
168
+
169
+ Refines the tracking of images using a fine track predictor and a fine feature network.
170
+ Check https://arxiv.org/abs/2312.04563 for more details.
171
+
172
+ Args:
173
+ images (torch.Tensor): The images to be tracked.
174
+ fine_fnet (nn.Module): The fine feature network.
175
+ fine_tracker (nn.Module): The fine track predictor.
176
+ coarse_pred (torch.Tensor): The coarse predictions of tracks.
177
+ compute_score (bool, optional): Whether to compute the score. Defaults to False.
178
+ pradius (int, optional): The radius of a patch. Defaults to 15.
179
+ sradius (int, optional): The search radius. Defaults to 2.
180
+
181
+ Returns:
182
+ torch.Tensor: The refined tracks.
183
+ torch.Tensor, optional: The score.
184
+ """
185
+
186
+ # coarse_pred shape: BxSxNx2,
187
+ # where B is the batch, S is the video/images length, and N is the number of tracks
188
+ # now we are going to extract patches with the center at coarse_pred
189
+ # Please note that the last dimension indicates x and y, and hence has a dim number of 2
190
+ B, S, N, _ = coarse_pred.shape
191
+ _, _, _, H, W = images.shape
192
+
193
+ # Given the radius of a patch, compute the patch size
194
+ psize = pradius * 2 + 1
195
+
196
+ # Note that we assume the first frame is the query frame
197
+ # so the 2D locations of the first frame are the query points
198
+ query_points = coarse_pred[:, 0]
199
+
200
+ # Given 2D positions, we can use grid_sample to extract patches
201
+ # but it takes too much memory.
202
+ # Instead, we use the floored track xy to sample patches.
203
+
204
+ # For example, if the query point xy is (128.16, 252.78),
205
+ # and the patch size is (31, 31),
206
+ # our goal is to extract the content of a rectangle
207
+ # with left top: (113.16, 237.78)
208
+ # and right bottom: (143.16, 267.78).
209
+ # However, we record the floored left top: (113, 237)
210
+ # and the offset (0.16, 0.78)
211
+ # Then what we need is just unfolding the images like in CNN,
212
+ # picking the content at [(113, 237), (143, 267)].
213
+ # Such operations are highly optimized at pytorch
214
+ # (well if you really want to use interpolation, check the function extract_glimpse() below)
215
+
216
+ with torch.no_grad():
217
+ content_to_extract = images.reshape(B * S, 3, H, W)
218
+ C_in = content_to_extract.shape[1]
219
+
220
+ # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
221
+ # for the detailed explanation of unfold()
222
+ # Here it runs sliding windows (psize x psize) to build patches
223
+ # The shape changes from
224
+ # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
225
+ # where Psize is the size of patch
226
+ content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
227
+
228
+ # Floor the coarse predictions to get integers and save the fractional/decimal
229
+ track_int = coarse_pred.floor().int()
230
+ track_frac = coarse_pred - track_int
231
+
232
+ # Note the points represent the center of patches
233
+ # now we get the location of the top left corner of patches
234
+ # because the output of pytorch unfold is indexed by the top left corner
235
+ topleft = track_int - pradius
236
+ topleft_BSN = topleft.clone()
237
+
238
+ # clamp the values so that we will not go out of indexes
239
+ # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
240
+ # You need to separately clamp x and y if H!=W
241
+ topleft = topleft.clamp(0, H - psize)
242
+
243
+ # Reshape from BxSxNx2 -> (B*S)xNx2
244
+ topleft = topleft.reshape(B * S, N, 2)
245
+
246
+ # Prepare batches for indexing, shape: (B*S)xN
247
+ batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
248
+
249
+ # Extract image patches based on top left corners
250
+ # extracted_patches: (B*S) x N x C_in x Psize x Psize
251
+ extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]]
252
+
253
+ # Feed patches to fine fnet for features
254
+ patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
255
+
256
+ C_out = patch_feat.shape[1]
257
+
258
+ # Refine the coarse tracks by fine_tracker
259
+
260
+ # reshape back to B x S x N x C_out x Psize x Psize
261
+ patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
262
+ patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
263
+
264
+ # Prepare for the query points for fine tracker
265
+ # They are relative to the patch left top corner,
266
+ # instead of the image top left corner now
267
+ # patch_query_points: N x 1 x 2
268
+ # only 1 here because for each patch we only have 1 query point
269
+ patch_query_points = track_frac[:, 0] + pradius
270
+ patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
271
+
272
+ # Feed the PATCH query points and tracks into fine tracker
273
+ fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
274
+ query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True
275
+ )
276
+
277
+ # relative to the patch top left
278
+ fine_pred_track = fine_pred_track_lists[-1].clone()
279
+
280
+ # From (relative to the patch top left) to (relative to the image top left)
281
+ for idx in range(len(fine_pred_track_lists)):
282
+ fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N)
283
+ fine_level = fine_level.squeeze(-2)
284
+ fine_level = fine_level + topleft_BSN
285
+ fine_pred_track_lists[idx] = fine_level
286
+
287
+ # relative to the image top left
288
+ refined_tracks = fine_pred_track_lists[-1].clone()
289
+ refined_tracks[:, 0] = query_points
290
+
291
+ score = None
292
+
293
+ if compute_score:
294
+ score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out)
295
+
296
+ return refined_tracks, score
297
+
298
+
299
+ ################################## NOTE: NOT USED ##################################
300
+
301
+
302
+ def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out):
303
+ """
304
+ Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps,
305
+ given the query point features and reference frame feature maps
306
+ """
307
+
308
+ from kornia.utils.grid import create_meshgrid
309
+ from kornia.geometry.subpix import dsnt
310
+
311
+ # query_point_feat initial shape: B x N x C_out,
312
+ # query_point_feat indicates the feat at the corresponding query points
313
+ # Therefore we don't have S dimension here
314
+ query_point_feat = query_point_feat.reshape(B, N, C_out)
315
+ # reshape and expand to B x (S-1) x N x C_out
316
+ query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1)
317
+ # and reshape to (B*(S-1)*N) x C_out
318
+ query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out)
319
+
320
+ # Radius and size for computing the score
321
+ ssize = sradius * 2 + 1
322
+
323
+ # Reshape, you know it, so many reshaping operations
324
+ patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N)
325
+
326
+ # Again, we unfold the patches to smaller patches
327
+ # so that we can then focus on smaller patches
328
+ # patch_feat_unfold shape:
329
+ # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize
330
+ # well a bit scary, but actually not
331
+ patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1)
332
+
333
+ # Do the same stuff as above, i.e., the same as extracting patches
334
+ fine_prediction_floor = fine_pred_track.floor().int()
335
+ fine_level_floor_topleft = fine_prediction_floor - sradius
336
+
337
+ # Clamp to ensure the smaller patch is valid
338
+ fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize)
339
+ fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2)
340
+
341
+ # Prepare the batch indices and xy locations
342
+
343
+ batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN
344
+ batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N
345
+ y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices
346
+ x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices
347
+
348
+ reference_frame_feat = patch_feat_unfold.reshape(
349
+ B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize
350
+ )
351
+
352
+ # Note again, according to pytorch convention
353
+ # x_indices corresponds to [..., 1] and y_indices corresponds to [..., 0]
354
+ reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices]
355
+ reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize)
356
+ # pick the frames other than the first one, so we have S-1 frames here
357
+ reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize)
358
+
359
+ # Compute similarity
360
+ sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat)
361
+ softmax_temp = 1.0 / C_out**0.5
362
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
363
+ # 2D heatmaps
364
+ heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize
365
+
366
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
367
+ grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape(
368
+ 1, -1, 2
369
+ )
370
+
371
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2
372
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability
373
+
374
+ score = std.reshape(B, S - 1, N)
375
+ # set score as 1 for the query frame
376
+ score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1)
377
+
378
+ return score
379
+
380
+
381
+ def extract_glimpse(
382
+ tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None
383
+ ):
384
+ B, C, W, H = tensor.shape
385
+
386
+ h, w = size
387
+ xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0
388
+ ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0
389
+
390
+ vy, vx = torch.meshgrid(ys, xs)
391
+ grid = torch.stack([vx, vy], dim=-1) # h, w, 2
392
+ grid = grid[None]
393
+
394
+ B, N, _ = offsets.shape
395
+
396
+ offsets = offsets.reshape((B * N), 1, 1, 2)
397
+ offsets_grid = offsets + grid
398
+
399
+ # normalised grid to [-1, 1]
400
+ offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2])
401
+
402
+ # BxCxHxW -> Bx1xCxHxW
403
+ tensor = tensor[:, None]
404
+
405
+ # Bx1xCxHxW -> BxNxCxHxW
406
+ tensor = tensor.expand(-1, N, -1, -1, -1)
407
+
408
+ # BxNxCxHxW -> (B*N)xCxHxW
409
+ tensor = tensor.reshape((B * N), C, W, H)
410
+
411
+ sampled = torch.nn.functional.grid_sample(
412
+ tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode
413
+ )
414
+
415
+ # NOTE: I am not sure it should be h, w or w, h here
416
+ # but okay for squares
417
+ sampled = sampled.reshape(B, N, C, h, w)
418
+
419
+ return sampled
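The floor/offset bookkeeping that refine_track performs before unfold() can be checked with the worked example from its own comments ((128.16, 252.78) with pradius=15); the snippet below just spells out that arithmetic and is not part of the module:

import torch

pradius = 15
psize = pradius * 2 + 1                      # 31: patch side length
coarse_xy = torch.tensor([[128.16, 252.78]]) # one coarse track position
track_int = coarse_xy.floor().int()          # (128, 252)
track_frac = coarse_xy - track_int           # (~0.16, ~0.78)
topleft = track_int - pradius                # (113, 237): index used into the unfolded patches
patch_query = track_frac + pradius           # (~15.16, ~15.78): query point relative to the patch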
capvector-pi05/src/vggt/dependency/track_modules/utils.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/PoseDiffusion
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
20
+ """
21
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
22
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
23
+ Args:
24
+ - embed_dim: The embedding dimension.
25
+ - grid_size: The grid size.
26
+ Returns:
27
+ - pos_embed: The generated 2D positional embedding.
28
+ """
29
+ if isinstance(grid_size, tuple):
30
+ grid_size_h, grid_size_w = grid_size
31
+ else:
32
+ grid_size_h = grid_size_w = grid_size
33
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
34
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
35
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
36
+ grid = torch.stack(grid, dim=0)
37
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
38
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
39
+ if return_grid:
40
+ return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
41
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
42
+
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
47
+
48
+ Args:
49
+ - embed_dim: The embedding dimension.
50
+ - grid: The grid to generate the embedding from.
51
+
52
+ Returns:
53
+ - emb: The generated 2D positional embedding.
54
+ """
55
+ assert embed_dim % 2 == 0
56
+
57
+ # use half of dimensions to encode grid_h
58
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
59
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
60
+
61
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
62
+ return emb
63
+
64
+
65
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
66
+ """
67
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
68
+
69
+ Args:
70
+ - embed_dim: The embedding dimension.
71
+ - pos: The position to generate the embedding from.
72
+
73
+ Returns:
74
+ - emb: The generated 1D positional embedding.
75
+ """
76
+ assert embed_dim % 2 == 0
77
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
78
+ omega /= embed_dim / 2.0
79
+ omega = 1.0 / 10000**omega # (D/2,)
80
+
81
+ pos = pos.reshape(-1) # (M,)
82
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
83
+
84
+ emb_sin = torch.sin(out) # (M, D/2)
85
+ emb_cos = torch.cos(out) # (M, D/2)
86
+
87
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
88
+ return emb[None].float()
89
+
90
+
91
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
92
+ """
93
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
94
+
95
+ Args:
96
+ - xy: The coordinates to generate the embedding from.
97
+ - C: The size of the embedding.
98
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
99
+
100
+ Returns:
101
+ - pe: The generated 2D positional embedding.
102
+ """
103
+ B, N, D = xy.shape
104
+ assert D == 2
105
+
106
+ x = xy[:, :, 0:1]
107
+ y = xy[:, :, 1:2]
108
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
109
+
110
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
111
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
112
+
113
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
114
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
115
+
116
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
117
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
118
+
119
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
120
+ if cat_coords:
121
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
122
+ return pe
123
+
124
+
125
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
126
+ r"""Sample a tensor using bilinear interpolation
127
+
128
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
129
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
130
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
131
+ convention.
132
+
133
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
134
+ :math:`B` is the batch size, :math:`C` is the number of channels,
135
+ :math:`H` is the height of the image, and :math:`W` is the width of the
136
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
137
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
138
+
139
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
140
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
141
+ that in this case the order of the components is slightly different
142
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
143
+
144
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
145
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
146
+ left-most image pixel :math:`W-1` to the center of the right-most
147
+ pixel.
148
+
149
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
150
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
151
+ the left-most pixel :math:`W` to the right edge of the right-most
152
+ pixel.
153
+
154
+ Similar conventions apply to the :math:`y` for the range
155
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
156
+ :math:`[0,T-1]` and :math:`[0,T]`.
157
+
158
+ Args:
159
+ input (Tensor): batch of input images.
160
+ coords (Tensor): batch of coordinates.
161
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
162
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
163
+
164
+ Returns:
165
+ Tensor: sampled points.
166
+ """
167
+
168
+ sizes = input.shape[2:]
169
+
170
+ assert len(sizes) in [2, 3]
171
+
172
+ if len(sizes) == 3:
173
+ # t x y -> x y t to match dimensions T H W in grid_sample
174
+ coords = coords[..., [1, 2, 0]]
175
+
176
+ if align_corners:
177
+ coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device)
178
+ else:
179
+ coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
180
+
181
+ coords -= 1
182
+
183
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
184
+
185
+
186
+ def sample_features4d(input, coords):
187
+ r"""Sample spatial features
188
+
189
+ `sample_features4d(input, coords)` samples the spatial features
190
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
191
+
192
+ The field is sampled at coordinates :attr:`coords` using bilinear
193
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
194
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
195
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
196
+
197
+ The output tensor has one feature per point, and has shape :math:`(B,
198
+ R, C)`.
199
+
200
+ Args:
201
+ input (Tensor): spatial features.
202
+ coords (Tensor): points.
203
+
204
+ Returns:
205
+ Tensor: sampled features.
206
+ """
207
+
208
+ B, _, _, _ = input.shape
209
+
210
+ # B R 2 -> B R 1 2
211
+ coords = coords.unsqueeze(2)
212
+
213
+ # B C R 1
214
+ feats = bilinear_sampler(input, coords)
215
+
216
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
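A short sketch of the helpers above, with shapes following their docstrings (the concrete sizes are placeholders):

import torch

pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(32, 32))   # 1 x 64 x 32 x 32 embedding grid
fmap = torch.randn(2, 128, 32, 32)                                 # B x C x H x W feature map
pts = torch.rand(2, 10, 2) * 31                                    # B x R x 2 points as (x, y)
feats = sample_features4d(fmap, pts)                               # B x R x C bilinearly sampled features
emb = get_2d_embedding(pts, C=32, cat_coords=True)                 # B x R x (2*32 + 2) sin/cos encoding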
capvector-pi05/src/vggt/dependency/track_predict.py ADDED
@@ -0,0 +1,326 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import numpy as np
9
+ from .vggsfm_utils import *
10
+
11
+
12
+ def predict_tracks(
13
+ images,
14
+ conf=None,
15
+ points_3d=None,
16
+ masks=None,
17
+ max_query_pts=2048,
18
+ query_frame_num=5,
19
+ keypoint_extractor="aliked+sp",
20
+ max_points_num=163840,
21
+ fine_tracking=True,
22
+ complete_non_vis=True,
23
+ ):
24
+ """
25
+ Predict tracks for the given images and masks.
26
+
27
+ TODO: support non-square images
28
+ TODO: support masks
29
+
30
+
31
+ This function predicts the tracks for the given images and masks using the specified query method
32
+ and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames.
33
+
34
+ Args:
35
+ images: Tensor of shape [S, 3, H, W] containing the input images.
36
+ conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None.
37
+ points_3d: Tensor containing 3D points. Default is None.
38
+ masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None.
39
+ max_query_pts: Maximum number of query points. Default is 2048.
40
+ query_frame_num: Number of query frames to use. Default is 5.
41
+ keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp".
42
+ max_points_num: Maximum number of points to process at once. Default is 163840.
43
+ fine_tracking: Whether to use fine tracking. Default is True.
44
+ complete_non_vis: Whether to augment non-visible frames. Default is True.
45
+
46
+ Returns:
47
+ pred_tracks: Numpy array containing the predicted tracks.
48
+ pred_vis_scores: Numpy array containing the visibility scores for the tracks.
49
+ pred_confs: Numpy array containing the confidence scores for the tracks.
50
+ pred_points_3d: Numpy array containing the 3D points for the tracks.
51
+ pred_colors: Numpy array containing the point colors for the tracks. (0, 255)
52
+ """
53
+
54
+ device = images.device
55
+ dtype = images.dtype
56
+ tracker = build_vggsfm_tracker().to(device, dtype)
57
+
58
+ # Find query frames
59
+ query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device)
60
+
61
+ # Add the first image to the front if not already present
62
+ if 0 in query_frame_indexes:
63
+ query_frame_indexes.remove(0)
64
+ query_frame_indexes = [0, *query_frame_indexes]
65
+
66
+ # TODO: add the functionality to handle the masks
67
+ keypoint_extractors = initialize_feature_extractors(
68
+ max_query_pts, extractor_method=keypoint_extractor, device=device
69
+ )
70
+
71
+ pred_tracks = []
72
+ pred_vis_scores = []
73
+ pred_confs = []
74
+ pred_points_3d = []
75
+ pred_colors = []
76
+
77
+ fmaps_for_tracker = tracker.process_images_to_fmaps(images)
78
+
79
+ if fine_tracking:
80
+ print("For faster inference, consider disabling fine_tracking")
81
+
82
+ for query_index in query_frame_indexes:
83
+ print(f"Predicting tracks for query frame {query_index}")
84
+ pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query(
85
+ query_index,
86
+ images,
87
+ conf,
88
+ points_3d,
89
+ fmaps_for_tracker,
90
+ keypoint_extractors,
91
+ tracker,
92
+ max_points_num,
93
+ fine_tracking,
94
+ device,
95
+ )
96
+
97
+ pred_tracks.append(pred_track)
98
+ pred_vis_scores.append(pred_vis)
99
+ pred_confs.append(pred_conf)
100
+ pred_points_3d.append(pred_point_3d)
101
+ pred_colors.append(pred_color)
102
+
103
+ if complete_non_vis:
104
+ pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames(
105
+ pred_tracks,
106
+ pred_vis_scores,
107
+ pred_confs,
108
+ pred_points_3d,
109
+ pred_colors,
110
+ images,
111
+ conf,
112
+ points_3d,
113
+ fmaps_for_tracker,
114
+ keypoint_extractors,
115
+ tracker,
116
+ max_points_num,
117
+ fine_tracking,
118
+ min_vis=500,
119
+ non_vis_thresh=0.1,
120
+ device=device,
121
+ )
122
+
123
+ pred_tracks = np.concatenate(pred_tracks, axis=1)
124
+ pred_vis_scores = np.concatenate(pred_vis_scores, axis=1)
125
+ pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs and pred_confs[0] is not None else None
126
+ pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d and pred_points_3d[0] is not None else None
127
+ pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None
128
+
129
+ # from vggt.utils.visual_track import visualize_tracks_on_images
130
+ # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals")
131
+
132
+ return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
133
+
134
+
135
+ def _forward_on_query(
136
+ query_index,
137
+ images,
138
+ conf,
139
+ points_3d,
140
+ fmaps_for_tracker,
141
+ keypoint_extractors,
142
+ tracker,
143
+ max_points_num,
144
+ fine_tracking,
145
+ device,
146
+ ):
147
+ """
148
+ Process a single query frame for track prediction.
149
+
150
+ Args:
151
+ query_index: Index of the query frame
152
+ images: Tensor of shape [S, 3, H, W] containing the input images
153
+ conf: Confidence tensor
154
+ points_3d: 3D points tensor
155
+ fmaps_for_tracker: Feature maps for the tracker
156
+ keypoint_extractors: Initialized feature extractors
157
+ tracker: VGG-SFM tracker
158
+ max_points_num: Maximum number of points to process at once
159
+ fine_tracking: Whether to use fine tracking
160
+ device: Device to use for computation
161
+
162
+ Returns:
163
+ pred_track: Predicted tracks
164
+ pred_vis: Visibility scores for the tracks
165
+ pred_conf: Confidence scores for the tracks
166
+ pred_point_3d: 3D points for the tracks
167
+ pred_color: Point colors for the tracks (0, 255)
168
+ """
169
+ frame_num, _, height, width = images.shape
170
+
171
+ query_image = images[query_index]
172
+ query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False)
173
+ query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)]
174
+
175
+ # Extract the color at the keypoint locations
176
+ query_points_long = query_points.squeeze(0).round().long()
177
+ pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]]
178
+ pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8)
179
+
180
+ # Query the confidence and points_3d at the keypoint locations
181
+ if (conf is not None) and (points_3d is not None):
182
+ assert height == width
183
+ assert conf.shape[-2] == conf.shape[-1]
184
+ assert conf.shape[:3] == points_3d.shape[:3]
185
+ scale = conf.shape[-1] / width
186
+
187
+ query_points_scaled = (query_points.squeeze(0) * scale).round().long()
188
+ query_points_scaled = query_points_scaled.cpu().numpy()
189
+
190
+ pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]]
191
+ pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]]
192
+
193
+ # heuristic to remove low confidence points
194
+ # TODO: consider exposing this confidence threshold as an input parameter
195
+ valid_mask = pred_conf > 1.2
196
+ if valid_mask.sum() > 512:
197
+ query_points = query_points[:, valid_mask] # Make sure shape is compatible
198
+ pred_conf = pred_conf[valid_mask]
199
+ pred_point_3d = pred_point_3d[valid_mask]
200
+ pred_color = pred_color[valid_mask]
201
+ else:
202
+ pred_conf = None
203
+ pred_point_3d = None
204
+
205
+ reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
206
+
207
+ images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0)
208
+ images_feed = images_feed[None] # add batch dimension
209
+ fmaps_feed = fmaps_feed[None] # add batch dimension
210
+
211
+ all_points_num = images_feed.shape[1] * query_points.shape[1]
212
+
213
+ # Chunk the query points so that each forward pass stays within GPU memory
214
+ if all_points_num > max_points_num:
215
+ num_splits = (all_points_num + max_points_num - 1) // max_points_num
216
+ query_points = torch.chunk(query_points, num_splits, dim=1)
217
+ else:
218
+ query_points = [query_points]
219
+
220
+ pred_track, pred_vis, _ = predict_tracks_in_chunks(
221
+ tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking
222
+ )
223
+
224
+ pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1)
225
+
226
+ pred_track = pred_track.squeeze(0).float().cpu().numpy()
227
+ pred_vis = pred_vis.squeeze(0).float().cpu().numpy()
228
+
229
+ return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color
230
+
231
+
232
+ def _augment_non_visible_frames(
233
+ pred_tracks: list, # ← running list of np.ndarrays
234
+ pred_vis_scores: list, # ← running list of np.ndarrays
235
+ pred_confs: list, # ← running list of np.ndarrays for confidence scores
236
+ pred_points_3d: list, # ← running list of np.ndarrays for 3D points
237
+ pred_colors: list, # ← running list of np.ndarrays for colors
238
+ images: torch.Tensor,
239
+ conf,
240
+ points_3d,
241
+ fmaps_for_tracker,
242
+ keypoint_extractors,
243
+ tracker,
244
+ max_points_num: int,
245
+ fine_tracking: bool,
246
+ *,
247
+ min_vis: int = 500,
248
+ non_vis_thresh: float = 0.1,
249
+ device: torch.device = None,
250
+ ):
251
+ """
252
+ Augment tracking for frames with insufficient visibility.
253
+
254
+ Args:
255
+ pred_tracks: List of numpy arrays containing predicted tracks.
256
+ pred_vis_scores: List of numpy arrays containing visibility scores.
257
+ pred_confs: List of numpy arrays containing confidence scores.
258
+ pred_points_3d: List of numpy arrays containing 3D points.
259
+ pred_colors: List of numpy arrays containing point colors.
260
+ images: Tensor of shape [S, 3, H, W] containing the input images.
261
+ conf: Tensor of shape [S, 1, H, W] containing confidence scores
262
+ points_3d: Tensor containing 3D points
263
+ fmaps_for_tracker: Feature maps for the tracker
264
+ keypoint_extractors: Initialized feature extractors
265
+ tracker: VGG-SFM tracker
266
+ max_points_num: Maximum number of points to process at once
267
+ fine_tracking: Whether to use fine tracking
268
+ min_vis: Minimum visibility threshold
269
+ non_vis_thresh: Non-visibility threshold
270
+ device: Device to use for computation
271
+
272
+ Returns:
273
+ Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists.
274
+ """
275
+ last_query = -1
276
+ final_trial = False
277
+ cur_extractors = keypoint_extractors # may be replaced on the final trial
278
+
279
+ while True:
280
+ # Visibility per frame
281
+ vis_array = np.concatenate(pred_vis_scores, axis=1)
282
+
283
+ # Count frames with sufficient visibility using numpy
284
+ sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
285
+ non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
286
+
287
+ if len(non_vis_frames) == 0:
288
+ break
289
+
290
+ print("Processing non-visible frames:", non_vis_frames)
291
+
292
+ # Decide the frames & extractor for this round
293
+ if non_vis_frames[0] == last_query:
294
+ # Same frame failed twice - final "all-in" attempt
295
+ final_trial = True
296
+ cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device)
297
+ query_frame_list = non_vis_frames # blast them all at once
298
+ else:
299
+ query_frame_list = [non_vis_frames[0]] # Process one at a time
300
+
301
+ last_query = non_vis_frames[0]
302
+
303
+ # Run the tracker for every selected frame
304
+ for query_index in query_frame_list:
305
+ new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
306
+ query_index,
307
+ images,
308
+ conf,
309
+ points_3d,
310
+ fmaps_for_tracker,
311
+ cur_extractors,
312
+ tracker,
313
+ max_points_num,
314
+ fine_tracking,
315
+ device,
316
+ )
317
+ pred_tracks.append(new_track)
318
+ pred_vis_scores.append(new_vis)
319
+ pred_confs.append(new_conf)
320
+ pred_points_3d.append(new_point_3d)
321
+ pred_colors.append(new_color)
322
+
323
+ if final_trial:
324
+ break # Stop after final attempt
325
+
326
+ return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
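To round off this file, here is a minimal, hypothetical usage sketch of predict_tracks as defined above. It assumes the package is importable as vggt, that a CUDA device and lightglue are available, and that network access exists for the DINO and VGGSfM tracker weight downloads; the optional conf / points_3d / masks arguments are simply omitted, and the frame count and sizes are illustrative.

import torch
from vggt.dependency.track_predict import predict_tracks

# Hypothetical input: 8 square RGB frames of size 518x518, values in [0, 1], on the GPU.
images = torch.rand(8, 3, 518, 518, device="cuda")

tracks, vis, confs, pts3d, colors = predict_tracks(
    images,
    max_query_pts=1024,      # fewer query points than the default for a quicker run
    query_frame_num=3,
    fine_tracking=False,     # coarse-only tracking is noticeably faster
)
print(tracks.shape, vis.shape)   # (S, N, 2) tracks and (S, N) visibility scores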
capvector-pi05/src/vggt/dependency/vggsfm_tracker.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from functools import partial
13
+ from torch import nn, einsum
14
+ from einops import rearrange, repeat
15
+ from einops.layers.torch import Rearrange, Reduce
16
+
17
+ from hydra.utils import instantiate
18
+ from omegaconf import OmegaConf
19
+
20
+ from .track_modules.track_refine import refine_track
21
+ from .track_modules.blocks import BasicEncoder, ShallowEncoder
22
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
23
+
24
+
25
+ class TrackerPredictor(nn.Module):
26
+ def __init__(self, **extra_args):
27
+ super(TrackerPredictor, self).__init__()
28
+ """
29
+ Initializes the tracker predictor.
30
+
31
+ Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor,
32
+ check track_modules/base_track_predictor.py
33
+
34
+ Both coarse_fnet and fine_fnet are constructed as a 2D CNN network
35
+ check track_modules/blocks.py for BasicEncoder and ShallowEncoder
36
+ """
37
+ # Define coarse predictor configuration
38
+ coarse_stride = 4
39
+ self.coarse_down_ratio = 2
40
+
41
+ # Create networks directly instead of using instantiate
42
+ self.coarse_fnet = BasicEncoder(stride=coarse_stride)
43
+ self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)
44
+
45
+ # Create fine predictor with stride = 1
46
+ self.fine_fnet = ShallowEncoder(stride=1)
47
+ self.fine_predictor = BaseTrackerPredictor(
48
+ stride=1,
49
+ depth=4,
50
+ corr_levels=3,
51
+ corr_radius=3,
52
+ latent_dim=32,
53
+ hidden_size=256,
54
+ fine=True,
55
+ use_spaceatt=False,
56
+ )
57
+
58
+ def forward(
59
+ self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960
60
+ ):
61
+ """
62
+ Args:
63
+ images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
64
+ query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
65
+ fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None.
66
+ coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
67
+ inference (bool, optional): Whether to perform inference. Defaults to True.
68
+ fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True.
+ fine_chunk (int, optional): Chunk size used when refining tracks. Defaults to 40960.
69
+
70
+ Returns:
71
+ tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score.
72
+ """
73
+
74
+ if fmaps is None:
75
+ batch_num, frame_num, image_dim, height, width = images.shape
76
+ reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)
77
+ fmaps = self.process_images_to_fmaps(reshaped_image)
78
+ fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1])
79
+
80
+ if inference:
81
+ torch.cuda.empty_cache()
82
+
83
+ # Coarse prediction
84
+ coarse_pred_track_lists, pred_vis = self.coarse_predictor(
85
+ query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio
86
+ )
87
+ coarse_pred_track = coarse_pred_track_lists[-1]
88
+
89
+ if inference:
90
+ torch.cuda.empty_cache()
91
+
92
+ if fine_tracking:
93
+ # Refine the coarse prediction
94
+ fine_pred_track, pred_score = refine_track(
95
+ images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk
96
+ )
97
+
98
+ if inference:
99
+ torch.cuda.empty_cache()
100
+ else:
101
+ fine_pred_track = coarse_pred_track
102
+ pred_score = torch.ones_like(pred_vis)
103
+
104
+ return fine_pred_track, coarse_pred_track, pred_vis, pred_score
105
+
106
+ def process_images_to_fmaps(self, images):
107
+ """
108
+ This function processes images for inference.
109
+
110
+ Args:
111
+ images (torch.Tensor): The images to be processed with shape S x 3 x H x W.
112
+
113
+ Returns:
114
+ torch.Tensor: The processed feature maps.
115
+ """
116
+ if self.coarse_down_ratio > 1:
117
+ # whether or not scale down the input images to save memory
118
+ fmaps = self.coarse_fnet(
119
+ F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True)
120
+ )
121
+ else:
122
+ fmaps = self.coarse_fnet(images)
123
+
124
+ return fmaps
capvector-pi05/src/vggt/dependency/vggsfm_utils.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import warnings
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import pycolmap
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from lightglue import ALIKED, SIFT, SuperPoint
16
+
17
+ from .vggsfm_tracker import TrackerPredictor
18
+
19
+ # Suppress verbose logging from dependencies
20
+ logging.getLogger("dinov2").setLevel(logging.WARNING)
21
+ warnings.filterwarnings("ignore", message="xFormers is available")
22
+ warnings.filterwarnings("ignore", message="dinov2")
23
+
24
+ # Constants
25
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
26
+ _RESNET_STD = [0.229, 0.224, 0.225]
27
+
28
+
29
+ def build_vggsfm_tracker(model_path=None):
30
+ """
31
+ Build and initialize the VGGSfM tracker.
32
+
33
+ Args:
34
+ model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace.
35
+
36
+ Returns:
37
+ Initialized tracker model in eval mode.
38
+ """
39
+ tracker = TrackerPredictor()
40
+
41
+ if model_path is None:
42
+ default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
43
+ tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
44
+ else:
45
+ tracker.load_state_dict(torch.load(model_path))
46
+
47
+ tracker.eval()
48
+ return tracker
49
+
50
+
51
+ def generate_rank_by_dino(
52
+ images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False
53
+ ):
54
+ """
55
+ Generate a ranking of frames using DINO ViT features.
56
+
57
+ Args:
58
+ images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
59
+ query_frame_num: Number of frames to select
60
+ image_size: Size to resize images to before processing
61
+ model_name: Name of the DINO model to use
62
+ device: Device to run the model on
63
+ spatial_similarity: Whether to use spatial token similarity or CLS token similarity
64
+
65
+ Returns:
66
+ List of frame indices ranked by their representativeness
67
+ """
68
+ # Resize images to the target size
69
+ images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False)
70
+
71
+ # Load DINO model
72
+ dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
73
+ dino_v2_model.eval()
74
+ dino_v2_model = dino_v2_model.to(device)
75
+
76
+ # Normalize images using ResNet normalization
77
+ resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
78
+ resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
79
+ images_resnet_norm = (images - resnet_mean) / resnet_std
80
+
81
+ with torch.no_grad():
82
+ frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
83
+
84
+ # Process features based on similarity type
85
+ if spatial_similarity:
86
+ frame_feat = frame_feat["x_norm_patchtokens"]
87
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
88
+
89
+ # Compute the similarity matrix
90
+ frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
91
+ similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
92
+ similarity_matrix = similarity_matrix.mean(dim=0)
93
+ else:
94
+ frame_feat = frame_feat["x_norm_clstoken"]
95
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
96
+ similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
97
+
98
+ distance_matrix = 100 - similarity_matrix.clone()
99
+
100
+ # Ignore self-pairing
101
+ similarity_matrix.fill_diagonal_(-100)
102
+ similarity_sum = similarity_matrix.sum(dim=1)
103
+
104
+ # Find the most common frame
105
+ most_common_frame_index = torch.argmax(similarity_sum).item()
106
+
107
+ # Conduct FPS sampling starting from the most common frame
108
+ fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)
109
+
110
+ # Clean up all tensors and models to free memory
111
+ del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
112
+ del dino_v2_model
113
+ torch.cuda.empty_cache()
114
+
115
+ return fps_idx
116
+
117
+
118
+ def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
119
+ """
120
+ Farthest point sampling algorithm to select diverse frames.
121
+
122
+ Args:
123
+ distance_matrix: Matrix of distances between frames
124
+ num_samples: Number of frames to select
125
+ most_common_frame_index: Index of the first frame to select
126
+
127
+ Returns:
128
+ List of selected frame indices
129
+ """
130
+ distance_matrix = distance_matrix.clamp(min=0)
131
+ N = distance_matrix.size(0)
132
+
133
+ # Initialize with the most common frame
134
+ selected_indices = [most_common_frame_index]
135
+ check_distances = distance_matrix[selected_indices]
136
+
137
+ while len(selected_indices) < num_samples:
138
+ # Find the farthest point from the current set of selected points
139
+ farthest_point = torch.argmax(check_distances)
140
+ selected_indices.append(farthest_point.item())
141
+
142
+ check_distances = distance_matrix[farthest_point]
143
+ # Mark already selected points to avoid selecting them again
144
+ check_distances[selected_indices] = 0
145
+
146
+ # Break if all points have been selected
147
+ if len(selected_indices) == N:
148
+ break
149
+
150
+ return selected_indices
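A tiny worked example of the greedy selection implemented above, using a hypothetical 4x4 distance matrix (no model or GPU needed). Starting from frame 0 it repeatedly picks the frame farthest from the most recently selected frame, skipping frames already chosen:

import torch
from vggt.dependency.vggsfm_utils import farthest_point_sampling

# Symmetric pairwise distances between 4 frames, zero on the diagonal.
d = torch.tensor([[0., 1., 9., 2.],
                  [1., 0., 3., 4.],
                  [9., 3., 0., 5.],
                  [2., 4., 5., 0.]])

picked = farthest_point_sampling(d, num_samples=3, most_common_frame_index=0)
assert picked == [0, 2, 3]   # frame 2 is farthest from 0, then frame 3 is farthest from 2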
151
+
152
+
153
+ def calculate_index_mappings(query_index, S, device=None):
154
+ """
155
+ Construct an order that switches [query_index] and [0]
156
+ so that the content of query_index would be placed at [0].
157
+
158
+ Args:
159
+ query_index: Index to swap with 0
160
+ S: Total number of elements
161
+ device: Device to place the tensor on
162
+
163
+ Returns:
164
+ Tensor of indices with the swapped order
165
+ """
166
+ new_order = torch.arange(S)
167
+ new_order[0] = query_index
168
+ new_order[query_index] = 0
169
+ if device is not None:
170
+ new_order = new_order.to(device)
171
+ return new_order
172
+
173
+
174
+ def switch_tensor_order(tensors, order, dim=1):
175
+ """
176
+ Reorder tensors along a specific dimension according to the given order.
177
+
178
+ Args:
179
+ tensors: List of tensors to reorder
180
+ order: Tensor of indices specifying the new order
181
+ dim: Dimension along which to reorder
182
+
183
+ Returns:
184
+ List of reordered tensors
185
+ """
186
+ return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors]
187
+
188
+
189
+ def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"):
190
+ """
191
+ Initialize feature extractors that can be reused based on a method string.
192
+
193
+ Args:
194
+ max_query_num: Maximum number of keypoints to extract
195
+ det_thres: Detection threshold for keypoint extraction
196
+ extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
197
+ device: Device to run extraction on
198
+
199
+ Returns:
200
+ Dictionary of initialized extractors
201
+ """
202
+ extractors = {}
203
+ methods = extractor_method.lower().split("+")
204
+
205
+ for method in methods:
206
+ method = method.strip()
207
+ if method == "aliked":
208
+ aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
209
+ extractors["aliked"] = aliked_extractor.to(device).eval()
210
+ elif method == "sp":
211
+ sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres)
212
+ extractors["sp"] = sp_extractor.to(device).eval()
213
+ elif method == "sift":
214
+ sift_extractor = SIFT(max_num_keypoints=max_query_num)
215
+ extractors["sift"] = sift_extractor.to(device).eval()
216
+ else:
217
+ print(f"Warning: Unknown feature extractor '{method}', ignoring.")
218
+
219
+ if not extractors:
220
+ print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.")
221
+ aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
222
+ extractors["aliked"] = aliked_extractor.to(device).eval()
223
+
224
+ return extractors
225
+
226
+
227
+ def extract_keypoints(query_image, extractors, round_keypoints=True):
228
+ """
229
+ Extract keypoints using pre-initialized feature extractors.
230
+
231
+ Args:
232
+ query_image: Input image tensor (3xHxW, range [0, 1])
233
+ extractors: Dictionary of initialized extractors
+ round_keypoints: Whether to round keypoint coordinates to integer pixel positions. Default is True.
234
+
235
+ Returns:
236
+ Tensor of keypoint coordinates (1xNx2)
237
+ """
238
+ query_points = None
239
+
240
+ with torch.no_grad():
241
+ for extractor_name, extractor in extractors.items():
242
+ query_points_data = extractor.extract(query_image, invalid_mask=None)
243
+ extractor_points = query_points_data["keypoints"]
244
+ if round_keypoints:
245
+ extractor_points = extractor_points.round()
246
+
247
+ if query_points is not None:
248
+ query_points = torch.cat([query_points, extractor_points], dim=1)
249
+ else:
250
+ query_points = extractor_points
251
+
252
+ return query_points
253
+
254
+
255
+ def predict_tracks_in_chunks(
256
+ track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960
257
+ ):
258
+ """
259
+ Process a list of query points to avoid memory issues.
260
+
261
+ Args:
262
+ track_predictor (object): The track predictor object used for predicting tracks.
263
+ images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images.
264
+ query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points.
265
+ fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
266
+ fine_tracking (bool): Whether to perform fine tracking.
267
+ num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility.
+ fine_chunk (int, optional): Chunk size forwarded to the fine refinement stage. Defaults to 40960.
268
+
269
+ Returns:
270
+ tuple: A tuple containing the concatenated predicted tracks, visibility, and scores.
271
+ """
272
+ # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility
273
+ if not isinstance(query_points_list, (list, tuple)):
274
+ query_points = query_points_list
275
+ if num_splits is None:
276
+ num_splits = 1
277
+ query_points_list = torch.chunk(query_points, num_splits, dim=1)
278
+
279
+ # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple)
280
+ if isinstance(query_points_list, tuple):
281
+ query_points_list = list(query_points_list)
282
+
283
+ fine_pred_track_list = []
284
+ pred_vis_list = []
285
+ pred_score_list = []
286
+
287
+ for split_points in query_points_list:
288
+ # Feed into track predictor for each split
289
+ fine_pred_track, _, pred_vis, pred_score = track_predictor(
290
+ images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk
291
+ )
292
+ fine_pred_track_list.append(fine_pred_track)
293
+ pred_vis_list.append(pred_vis)
294
+ pred_score_list.append(pred_score)
295
+
296
+ # Concatenate the results from all splits
297
+ fine_pred_track = torch.cat(fine_pred_track_list, dim=2)
298
+ pred_vis = torch.cat(pred_vis_list, dim=2)
299
+
300
+ if pred_score is not None:
301
+ pred_score = torch.cat(pred_score_list, dim=2)
302
+ else:
303
+ pred_score = None
304
+
305
+ return fine_pred_track, pred_vis, pred_score
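As a self-contained check of the frame-reordering helpers defined earlier in this file, the sketch below swaps a chosen query frame to position 0 with calculate_index_mappings / switch_tensor_order and verifies that applying the same order twice restores the original layout (pure tensor manipulation, no weights required):

import torch
from vggt.dependency.vggsfm_utils import calculate_index_mappings, switch_tensor_order

S, query_index = 5, 3
frames = torch.arange(S).float().view(S, 1)          # stand-in for per-frame tensors

order = calculate_index_mappings(query_index, S)      # swaps positions 0 and 3
(reordered,) = switch_tensor_order([frames], order, dim=0)
assert reordered[0].item() == query_index             # the query frame now sits first

# The swap is its own inverse, so reapplying the same order undoes it.
(restored,) = switch_tensor_order([reordered], order, dim=0)
assert torch.equal(restored, frames)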
capvector-pi05/src/vggt/heads/camera_head.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from vggt.layers import Mlp
15
+ from vggt.layers.block import Block
16
+ from vggt.heads.head_act import activate_pose
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
54
+ for _ in range(trunk_depth)
55
+ ]
56
+ )
57
+
58
+ # Normalizations for camera token and trunk output.
59
+ self.token_norm = nn.LayerNorm(dim_in)
60
+ self.trunk_norm = nn.LayerNorm(dim_in)
61
+
62
+ # Learnable empty camera pose token.
63
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
64
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
65
+
66
+ # Module for producing modulation parameters: shift, scale, and a gate.
67
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
68
+
69
+ # Adaptive layer normalization without affine parameters.
70
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
71
+ self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
72
+
73
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
74
+ """
75
+ Forward pass to predict camera parameters.
76
+
77
+ Args:
78
+ aggregated_tokens_list (list): List of token tensors from the network;
79
+ the last tensor is used for prediction.
80
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
81
+
82
+ Returns:
83
+ list: A list of predicted camera encodings (post-activation) from each iteration.
84
+ """
85
+ # Use tokens from the last block for camera prediction.
86
+ tokens = aggregated_tokens_list[-1]
87
+
88
+ # Extract the camera tokens
89
+ pose_tokens = tokens[:, :, 0]
90
+ pose_tokens = self.token_norm(pose_tokens)
91
+
92
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
93
+ return pred_pose_enc_list
94
+
95
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
96
+ """
97
+ Iteratively refine camera pose predictions.
98
+
99
+ Args:
100
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
101
+ num_iterations (int): Number of refinement iterations.
102
+
103
+ Returns:
104
+ list: List of activated camera encodings from each iteration.
105
+ """
106
+ B, S, C = pose_tokens.shape
107
+ pred_pose_enc = None
108
+ pred_pose_enc_list = []
109
+
110
+ for _ in range(num_iterations):
111
+ # Use a learned empty pose for the first iteration.
112
+ if pred_pose_enc is None:
113
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
114
+ else:
115
+ # Detach the previous prediction to avoid backprop through time.
116
+ pred_pose_enc = pred_pose_enc.detach()
117
+ module_input = self.embed_pose(pred_pose_enc)
118
+
119
+ # Generate modulation parameters and split them into shift, scale, and gate components.
120
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
121
+
122
+ # Adaptive layer normalization and modulation.
123
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
124
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
125
+
126
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
127
+ # Compute the delta update for the pose encoding.
128
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
129
+
130
+ if pred_pose_enc is None:
131
+ pred_pose_enc = pred_pose_enc_delta
132
+ else:
133
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
134
+
135
+ # Apply final activation functions for translation, quaternion, and field-of-view.
136
+ activated_pose = activate_pose(
137
+ pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
138
+ )
139
+ pred_pose_enc_list.append(activated_pose)
140
+
141
+ return pred_pose_enc_list
142
+
143
+
144
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
145
+ """
146
+ Modulate the input tensor using scaling and shifting parameters.
147
+ """
148
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
149
+ return x * (1 + scale) + shift
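The modulate helper at the end of this file is the standard DiT-style affine modulation used by CameraHead's adaptive layer norm. A small illustrative check (assuming the vggt package imports cleanly): with zero shift and scale it is the identity, and otherwise it simply scales and shifts the normalized tokens.

import torch
from vggt.heads.camera_head import modulate

x = torch.randn(2, 3, 8)                                       # [B, S, C] pose tokens
zeros = torch.zeros_like(x)
assert torch.equal(modulate(x, shift=zeros, scale=zeros), x)   # identity when unmodulated

shift = torch.full_like(x, 0.5)
scale = torch.full_like(x, 2.0)
assert torch.allclose(modulate(x, shift, scale), x * 3.0 + 0.5)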
capvector-pi05/src/vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,484 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
71
+ )
72
+
73
+ # Resize layers for upsampling feature maps.
74
+ self.resize_layers = nn.ModuleList(
75
+ [
76
+ nn.ConvTranspose2d(
77
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
78
+ ),
79
+ nn.ConvTranspose2d(
80
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
81
+ ),
82
+ nn.Identity(),
83
+ nn.Conv2d(
84
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
85
+ ),
86
+ ]
87
+ )
88
+
89
+ self.scratch = _make_scratch(out_channels, features, expand=False)
90
+
91
+ # Attach additional modules to scratch.
92
+ self.scratch.stem_transpose = None
93
+ self.scratch.refinenet1 = _make_fusion_block(features)
94
+ self.scratch.refinenet2 = _make_fusion_block(features)
95
+ self.scratch.refinenet3 = _make_fusion_block(features)
96
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
97
+
98
+ head_features_1 = features
99
+ head_features_2 = 32
100
+
101
+ if feature_only:
102
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
103
+ else:
104
+ self.scratch.output_conv1 = nn.Conv2d(
105
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
106
+ )
107
+ conv2_in_channels = head_features_1 // 2
108
+
109
+ self.scratch.output_conv2 = nn.Sequential(
110
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
111
+ nn.ReLU(inplace=True),
112
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
113
+ )
114
+
115
+ def forward(
116
+ self,
117
+ aggregated_tokens_list: List[torch.Tensor],
118
+ images: torch.Tensor,
119
+ patch_start_idx: int,
120
+ frames_chunk_size: int = 8,
121
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
122
+ """
123
+ Forward pass through the DPT head, supports processing by chunking frames.
124
+ Args:
125
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
126
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
127
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
128
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
129
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
130
+ If None or larger than S, all frames are processed at once. Default: 8.
131
+
132
+ Returns:
133
+ Tensor or Tuple[Tensor, Tensor]:
134
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
135
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
136
+ """
137
+ B, S, _, H, W = images.shape
138
+
139
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
140
+ if frames_chunk_size is None or frames_chunk_size >= S:
141
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
142
+
143
+ # Otherwise, process frames in chunks to manage memory usage
144
+ assert frames_chunk_size > 0
145
+
146
+ # Process frames in batches
147
+ all_preds = []
148
+ all_conf = []
149
+
150
+ for frames_start_idx in range(0, S, frames_chunk_size):
151
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
152
+
153
+ # Process batch of frames
154
+ if self.feature_only:
155
+ chunk_output = self._forward_impl(
156
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
157
+ )
158
+ all_preds.append(chunk_output)
159
+ else:
160
+ chunk_preds, chunk_conf = self._forward_impl(
161
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
162
+ )
163
+ all_preds.append(chunk_preds)
164
+ all_conf.append(chunk_conf)
165
+
166
+ # Concatenate results along the sequence dimension
167
+ if self.feature_only:
168
+ return torch.cat(all_preds, dim=1)
169
+ else:
170
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
171
+
172
+ def _forward_impl(
173
+ self,
174
+ aggregated_tokens_list: List[torch.Tensor],
175
+ images: torch.Tensor,
176
+ patch_start_idx: int,
177
+ frames_start_idx: int = None,
178
+ frames_end_idx: int = None,
179
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
180
+ """
181
+ Implementation of the forward pass through the DPT head.
182
+
183
+ This method processes a specific chunk of frames from the sequence.
184
+
185
+ Args:
186
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
187
+ images (Tensor): Input images with shape [B, S, 3, H, W].
188
+ patch_start_idx (int): Starting index for patch tokens.
189
+ frames_start_idx (int, optional): Starting index for frames to process.
190
+ frames_end_idx (int, optional): Ending index for frames to process.
191
+
192
+ Returns:
193
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
194
+ """
195
+ if frames_start_idx is not None and frames_end_idx is not None:
196
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
197
+
198
+ B, S, _, H, W = images.shape
199
+
200
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
201
+
202
+ out = []
203
+ dpt_idx = 0
204
+
205
+ for layer_idx in self.intermediate_layer_idx:
206
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
207
+
208
+ # Select frames if processing a chunk
209
+ if frames_start_idx is not None and frames_end_idx is not None:
210
+ x = x[:, frames_start_idx:frames_end_idx]
211
+
212
+ x = x.reshape(B * S, -1, x.shape[-1])
213
+
214
+ x = self.norm(x)
215
+
216
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
217
+
218
+ x = self.projects[dpt_idx](x)
219
+ if self.pos_embed:
220
+ x = self._apply_pos_embed(x, W, H)
221
+ x = self.resize_layers[dpt_idx](x)
222
+
223
+ out.append(x)
224
+ dpt_idx += 1
225
+
226
+ # Fuse features from multiple layers.
227
+ out = self.scratch_forward(out)
228
+ # Interpolate fused output to match target image resolution.
229
+ out = custom_interpolate(
230
+ out,
231
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
232
+ mode="bilinear",
233
+ align_corners=True,
234
+ )
235
+
236
+ if self.pos_embed:
237
+ out = self._apply_pos_embed(out, W, H)
238
+
239
+ if self.feature_only:
240
+ return out.view(B, S, *out.shape[1:])
241
+
242
+ out = self.scratch.output_conv2(out)
243
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
244
+
245
+ preds = preds.view(B, S, *preds.shape[1:])
246
+ conf = conf.view(B, S, *conf.shape[1:])
247
+ return preds, conf
248
+
249
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
250
+ """
251
+ Apply positional embedding to tensor x.
252
+ """
253
+ patch_w = x.shape[-1]
254
+ patch_h = x.shape[-2]
255
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
256
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
257
+ pos_embed = pos_embed * ratio
258
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
259
+ return x + pos_embed
260
+
261
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
262
+ """
263
+ Forward pass through the fusion blocks.
264
+
265
+ Args:
266
+ features (List[Tensor]): List of feature maps from different layers.
267
+
268
+ Returns:
269
+ Tensor: Fused feature map.
270
+ """
271
+ layer_1, layer_2, layer_3, layer_4 = features
272
+
273
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
274
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
275
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
276
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
277
+
278
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
279
+ del layer_4_rn, layer_4
280
+
281
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
282
+ del layer_3_rn, layer_3
283
+
284
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
285
+ del layer_2_rn, layer_2
286
+
287
+ out = self.scratch.refinenet1(out, layer_1_rn)
288
+ del layer_1_rn, layer_1
289
+
290
+ out = self.scratch.output_conv1(out)
291
+ return out
292
+
293
+
294
+ ################################################################################
295
+ # Modules
296
+ ################################################################################
297
+
298
+
299
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
300
+ return FeatureFusionBlock(
301
+ features,
302
+ nn.ReLU(inplace=True),
303
+ deconv=False,
304
+ bn=False,
305
+ expand=False,
306
+ align_corners=True,
307
+ size=size,
308
+ has_residual=has_residual,
309
+ groups=groups,
310
+ )
311
+
312
+
313
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
314
+ scratch = nn.Module()
315
+ out_shape1 = out_shape
316
+ out_shape2 = out_shape
317
+ out_shape3 = out_shape
318
+ if len(in_shape) >= 4:
319
+ out_shape4 = out_shape
320
+
321
+ if expand:
322
+ out_shape1 = out_shape
323
+ out_shape2 = out_shape * 2
324
+ out_shape3 = out_shape * 4
325
+ if len(in_shape) >= 4:
326
+ out_shape4 = out_shape * 8
327
+
328
+ scratch.layer1_rn = nn.Conv2d(
329
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
330
+ )
331
+ scratch.layer2_rn = nn.Conv2d(
332
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
333
+ )
334
+ scratch.layer3_rn = nn.Conv2d(
335
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
336
+ )
337
+ if len(in_shape) >= 4:
338
+ scratch.layer4_rn = nn.Conv2d(
339
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
340
+ )
341
+ return scratch
342
+
343
+
344
+ class ResidualConvUnit(nn.Module):
345
+ """Residual convolution module."""
346
+
347
+ def __init__(self, features, activation, bn, groups=1):
348
+ """Init.
349
+
350
+ Args:
351
+ features (int): number of features
352
+ """
353
+ super().__init__()
354
+
355
+ self.bn = bn
356
+ self.groups = groups
357
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
358
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
359
+
360
+ self.norm1 = None
361
+ self.norm2 = None
362
+
363
+ self.activation = activation
364
+ self.skip_add = nn.quantized.FloatFunctional()
365
+
366
+ def forward(self, x):
367
+ """Forward pass.
368
+
369
+ Args:
370
+ x (tensor): input
371
+
372
+ Returns:
373
+ tensor: output
374
+ """
375
+
376
+ out = self.activation(x)
377
+ out = self.conv1(out)
378
+ if self.norm1 is not None:
379
+ out = self.norm1(out)
380
+
381
+ out = self.activation(out)
382
+ out = self.conv2(out)
383
+ if self.norm2 is not None:
384
+ out = self.norm2(out)
385
+
386
+ return self.skip_add.add(out, x)
387
+
388
+
389
+ class FeatureFusionBlock(nn.Module):
390
+ """Feature fusion block."""
391
+
392
+ def __init__(
393
+ self,
394
+ features,
395
+ activation,
396
+ deconv=False,
397
+ bn=False,
398
+ expand=False,
399
+ align_corners=True,
400
+ size=None,
401
+ has_residual=True,
402
+ groups=1,
403
+ ):
404
+ """Init.
405
+
406
+ Args:
407
+ features (int): number of features
408
+ """
409
+ super(FeatureFusionBlock, self).__init__()
410
+
411
+ self.deconv = deconv
412
+ self.align_corners = align_corners
413
+ self.groups = groups
414
+ self.expand = expand
415
+ out_features = features
416
+ if self.expand:
417
+ out_features = features // 2
418
+
419
+ self.out_conv = nn.Conv2d(
420
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
421
+ )
422
+
423
+ if has_residual:
424
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
425
+
426
+ self.has_residual = has_residual
427
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
428
+
429
+ self.skip_add = nn.quantized.FloatFunctional()
430
+ self.size = size
431
+
432
+ def forward(self, *xs, size=None):
433
+ """Forward pass.
434
+
435
+ Returns:
436
+ tensor: output
437
+ """
438
+ output = xs[0]
439
+
440
+ if self.has_residual:
441
+ res = self.resConfUnit1(xs[1])
442
+ output = self.skip_add.add(output, res)
443
+
444
+ output = self.resConfUnit2(output)
445
+
446
+ if (size is None) and (self.size is None):
447
+ modifier = {"scale_factor": 2}
448
+ elif size is None:
449
+ modifier = {"size": self.size}
450
+ else:
451
+ modifier = {"size": size}
452
+
453
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
454
+ output = self.out_conv(output)
455
+
456
+ return output
457
+
458
+
459
+ def custom_interpolate(
460
+ x: torch.Tensor,
461
+ size: Tuple[int, int] = None,
462
+ scale_factor: float = None,
463
+ mode: str = "bilinear",
464
+ align_corners: bool = True,
465
+ ) -> torch.Tensor:
466
+ """
467
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
468
+ """
469
+ if size is None:
470
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
471
+
472
+ INT_MAX = 1610612736
473
+
474
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
475
+
476
+ if input_elements > INT_MAX:
477
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
478
+ interpolated_chunks = [
479
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
480
+ ]
481
+ x = torch.cat(interpolated_chunks, dim=0)
482
+ return x.contiguous()
483
+ else:
484
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
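For ordinary tensor sizes custom_interpolate above simply takes the plain F.interpolate path; the chunked branch only kicks in when the element count would exceed the hard-coded INT_MAX guard. A quick equivalence check for the common case:

import torch
import torch.nn.functional as F
from vggt.heads.dpt_head import custom_interpolate

x = torch.randn(2, 16, 32, 32)
a = custom_interpolate(x, size=(64, 64), mode="bilinear", align_corners=True)
b = F.interpolate(x, size=(64, 64), mode="bilinear", align_corners=True)
assert torch.allclose(a, b)          # identical when no chunking is needed
print(a.shape)                       # torch.Size([2, 16, 64, 64])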
capvector-pi05/src/vggt/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
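inverse_log_transform undoes the signed log1p compression sign(x) * log(1 + |x|), which is what the "inv_log" and "xy_inv_log" branches above rely on to recover metric values. A quick round-trip check:

import torch
from vggt.heads.head_act import inverse_log_transform

x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
compressed = torch.sign(x) * torch.log1p(torch.abs(x))   # forward signed-log compression
recovered = inverse_log_transform(compressed)
assert torch.allclose(recovered, x, atol=1e-5)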
capvector-pi05/src/vggt/heads/track_head.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ from .dpt_head import DPTHead
9
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
10
+
11
+
12
+ class TrackHead(nn.Module):
13
+ """
14
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
15
+ The tracking is performed iteratively, refining predictions over multiple iterations.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ dim_in,
21
+ patch_size=14,
22
+ features=128,
23
+ iters=4,
24
+ predict_conf=True,
25
+ stride=2,
26
+ corr_levels=7,
27
+ corr_radius=4,
28
+ hidden_size=384,
29
+ ):
30
+ """
31
+ Initialize the TrackHead module.
32
+
33
+ Args:
34
+ dim_in (int): Input dimension of tokens from the backbone.
35
+ patch_size (int): Size of image patches used in the vision transformer.
36
+ features (int): Number of feature channels in the feature extractor output.
37
+ iters (int): Number of refinement iterations for tracking predictions.
38
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
39
+ stride (int): Stride value for the tracker predictor.
40
+ corr_levels (int): Number of correlation pyramid levels
41
+ corr_radius (int): Radius for correlation computation, controlling the search area.
42
+ hidden_size (int): Size of hidden layers in the tracker network.
43
+ """
44
+ super().__init__()
45
+
46
+ self.patch_size = patch_size
47
+
48
+ # Feature extractor based on DPT architecture
49
+ # Processes tokens into feature maps for tracking
50
+ self.feature_extractor = DPTHead(
51
+ dim_in=dim_in,
52
+ patch_size=patch_size,
53
+ features=features,
54
+ feature_only=True, # Only output features, no activation
55
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
56
+ pos_embed=False,
57
+ )
58
+
59
+ # Tracker module that predicts point trajectories
60
+ # Takes feature maps and predicts coordinates and visibility
61
+ self.tracker = BaseTrackerPredictor(
62
+ latent_dim=features, # Match the output_dim of feature extractor
63
+ predict_conf=predict_conf,
64
+ stride=stride,
65
+ corr_levels=corr_levels,
66
+ corr_radius=corr_radius,
67
+ hidden_size=hidden_size,
68
+ )
69
+
70
+ self.iters = iters
71
+
72
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
73
+ """
74
+ Forward pass of the TrackHead.
75
+
76
+ Args:
77
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
78
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
79
+ B = batch size, S = sequence length.
80
+ patch_start_idx (int): Starting index for patch tokens.
81
+ query_points (torch.Tensor, optional): Initial query points to track.
82
+ If None, points are initialized by the tracker.
83
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
84
+
85
+ Returns:
86
+ tuple:
87
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
88
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
89
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
90
+ """
91
+ B, S, _, H, W = images.shape
92
+
93
+ # Extract features from tokens
94
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
95
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
96
+
97
+ # Use default iterations if not specified
98
+ if iters is None:
99
+ iters = self.iters
100
+
101
+ # Perform tracking using the extracted features
102
+ coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters)
103
+
104
+ return coord_preds, vis_scores, conf_scores
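A small shape-contract sketch of how this head is wired, using dummy tensors only (the real DPTHead and BaseTrackerPredictor are not invoked here; dimensions are illustrative assumptions):

```python
import torch

# Illustrative shapes only (placeholders, not model outputs).
B, S, H, W = 2, 4, 224, 224      # batch, sequence length, image size
features, N = 128, 16            # feature channels, number of query points

images = torch.zeros(B, S, 3, H, W)
# With down_ratio=2 the DPT feature extractor is documented to return
# feature maps at half the image resolution:
fmaps = torch.zeros(B, S, features, H // 2, W // 2)
# Query points are given in full-resolution pixel coordinates:
query_points = torch.rand(B, N, 2) * torch.tensor([W, H], dtype=torch.float32)
# The tracker then returns, per refinement iteration, coordinates of shape
# (B, S, N, 2) plus visibility and optional confidence scores of shape (B, S, N).
```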
capvector-pi05/src/vggt/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
capvector-pi05/src/vggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ from .blocks import EfficientUpdateFormer, CorrBlock
13
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
14
+ from .modules import Mlp
15
+
16
+
17
+ class BaseTrackerPredictor(nn.Module):
18
+ def __init__(
19
+ self,
20
+ stride=1,
21
+ corr_levels=5,
22
+ corr_radius=4,
23
+ latent_dim=128,
24
+ hidden_size=384,
25
+ use_spaceatt=True,
26
+ depth=6,
27
+ max_scale=518,
28
+ predict_conf=True,
29
+ ):
30
+ super(BaseTrackerPredictor, self).__init__()
31
+ """
32
+ The base template to create a track predictor
33
+
34
+ Modified from https://github.com/facebookresearch/co-tracker/
35
+ and https://github.com/facebookresearch/vggsfm
36
+ """
37
+
38
+ self.stride = stride
39
+ self.latent_dim = latent_dim
40
+ self.corr_levels = corr_levels
41
+ self.corr_radius = corr_radius
42
+ self.hidden_size = hidden_size
43
+ self.max_scale = max_scale
44
+ self.predict_conf = predict_conf
45
+
46
+ self.flows_emb_dim = latent_dim // 2
47
+
48
+ self.corr_mlp = Mlp(
49
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
50
+ hidden_features=self.hidden_size,
51
+ out_features=self.latent_dim,
52
+ )
53
+
54
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
55
+
56
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
57
+
58
+ space_depth = depth if use_spaceatt else 0
59
+ time_depth = depth
60
+
61
+ self.updateformer = EfficientUpdateFormer(
62
+ space_depth=space_depth,
63
+ time_depth=time_depth,
64
+ input_dim=self.transformer_dim,
65
+ hidden_size=self.hidden_size,
66
+ output_dim=self.latent_dim + 2,
67
+ mlp_ratio=4.0,
68
+ add_space_attn=use_spaceatt,
69
+ )
70
+
71
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
72
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
73
+
74
+ # A linear layer to update track feats at each iteration
75
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
76
+
77
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
78
+
79
+ if predict_conf:
80
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
81
+
82
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
83
+ """
84
+ query_points: B x N x 2, the number of batches, tracks, and xy
85
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
86
+ note that HH and WW are the spatial sizes of the feature maps, not of the original images
87
+ """
88
+ B, N, D = query_points.shape
89
+ B, S, C, HH, WW = fmaps.shape
90
+
91
+ assert D == 2, "Input points must be 2D coordinates"
92
+
93
+ # apply a layernorm to fmaps here
94
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
95
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
96
+
97
+ # Scale the input query_points because we may downsample the images
98
+ # by down_ratio or self.stride
99
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
100
+ # its query_points should be query_points/4
101
+ if down_ratio > 1:
102
+ query_points = query_points / float(down_ratio)
103
+
104
+ query_points = query_points / float(self.stride)
105
+
106
+ # Init with coords as the query points
107
+ # It means the search will start from the position of query points at the reference frames
108
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
109
+
110
+ # Sample/extract the features of the query points in the query frame
111
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
112
+
113
+ # init track feats by query feats
114
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
115
+ # back up the init coords
116
+ coords_backup = coords.clone()
117
+
118
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
119
+
120
+ coord_preds = []
121
+
122
+ # Iterative Refinement
123
+ for _ in range(iters):
124
+ # Detach the gradients from the last iteration
125
+ # (in my experience, not very important for performance)
126
+ coords = coords.detach()
127
+
128
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
129
+
130
+ corr_dim = fcorrs.shape[3]
131
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
132
+ fcorrs_ = self.corr_mlp(fcorrs_)
133
+
134
+ # Movement of current coords relative to query points
135
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
136
+
137
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
138
+
139
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
140
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
141
+
142
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
143
+
144
+ # Concatenate them as the input for the transformers
145
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
146
+
147
+ # 2D positional embed
148
+ # TODO: this can be much simplified
149
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
150
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
151
+
152
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
153
+
154
+ x = transformer_input + sampled_pos_emb
155
+
156
+ # Add the query ref token to the track feats
157
+ query_ref_token = torch.cat(
158
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
159
+ )
160
+ x = x + query_ref_token.to(x.device).to(x.dtype)
161
+
162
+ # B, N, S, C
163
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
164
+
165
+ # Compute the delta coordinates and delta track features
166
+ delta, _ = self.updateformer(x)
167
+
168
+ # BN, S, C
169
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
170
+ delta_coords_ = delta[:, :, :2]
171
+ delta_feats_ = delta[:, :, 2:]
172
+
173
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
174
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
175
+
176
+ # Update the track features
177
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
178
+
179
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
180
+
181
+ # B x S x N x 2
182
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
183
+
184
+ # Force coord0 as query
185
+ # because we assume the query points should not be changed
186
+ coords[:, 0] = coords_backup[:, 0]
187
+
188
+ # The predicted tracks are in the original image scale
189
+ if down_ratio > 1:
190
+ coord_preds.append(coords * self.stride * down_ratio)
191
+ else:
192
+ coord_preds.append(coords * self.stride)
193
+
194
+ # B, S, N
195
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
196
+ if apply_sigmoid:
197
+ vis_e = torch.sigmoid(vis_e)
198
+
199
+ if self.predict_conf:
200
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
201
+ if apply_sigmoid:
202
+ conf_e = torch.sigmoid(conf_e)
203
+ else:
204
+ conf_e = None
205
+
206
+ if return_feat:
207
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
208
+ else:
209
+ return coord_preds, vis_e, conf_e
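A minimal numeric sketch of the coordinate scaling used above (values assumed for illustration): query points arrive in original-image pixels, correlation is computed on feature maps that are smaller by `down_ratio * stride`, and the predicted tracks are scaled back to image pixels before being returned.

```python
import torch

down_ratio, stride = 2, 2
query_points = torch.tensor([[100.0, 60.0], [512.0, 384.0]])   # (N, 2) in pixels

coords_feat = query_points / down_ratio / stride    # coordinates in feature-map units
coords_img = coords_feat * stride * down_ratio      # scaled back, as in coord_preds
print(torch.allclose(coords_img, query_points))     # True
```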
capvector-pi05/src/vggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import bilinear_sampler
16
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
17
+
18
+
19
+ class EfficientUpdateFormer(nn.Module):
20
+ """
21
+ Transformer model that updates track estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ space_depth=6,
27
+ time_depth=6,
28
+ input_dim=320,
29
+ hidden_size=384,
30
+ num_heads=8,
31
+ output_dim=130,
32
+ mlp_ratio=4.0,
33
+ add_space_attn=True,
34
+ num_virtual_tracks=64,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.out_channels = 2
39
+ self.num_heads = num_heads
40
+ self.hidden_size = hidden_size
41
+ self.add_space_attn = add_space_attn
42
+
43
+ # Add input LayerNorm before linear projection
44
+ self.input_norm = nn.LayerNorm(input_dim)
45
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
46
+
47
+ # Add output LayerNorm before final projection
48
+ self.output_norm = nn.LayerNorm(hidden_size)
49
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
50
+ self.num_virtual_tracks = num_virtual_tracks
51
+
52
+ if self.add_space_attn:
53
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
54
+ else:
55
+ self.virual_tracks = None
56
+
57
+ self.time_blocks = nn.ModuleList(
58
+ [
59
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
60
+ for _ in range(time_depth)
61
+ ]
62
+ )
63
+
64
+ if add_space_attn:
65
+ self.space_virtual_blocks = nn.ModuleList(
66
+ [
67
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
68
+ for _ in range(space_depth)
69
+ ]
70
+ )
71
+ self.space_point2virtual_blocks = nn.ModuleList(
72
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
73
+ )
74
+ self.space_virtual2point_blocks = nn.ModuleList(
75
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
76
+ )
77
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
78
+ self.initialize_weights()
79
+
80
+ def initialize_weights(self):
81
+ def _basic_init(module):
82
+ if isinstance(module, nn.Linear):
83
+ torch.nn.init.xavier_uniform_(module.weight)
84
+ if module.bias is not None:
85
+ nn.init.constant_(module.bias, 0)
86
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
87
+
88
+ self.apply(_basic_init)
89
+
90
+ def forward(self, input_tensor, mask=None):
91
+ # Apply input LayerNorm
92
+ input_tensor = self.input_norm(input_tensor)
93
+ tokens = self.input_transform(input_tensor)
94
+
95
+ init_tokens = tokens
96
+
97
+ B, _, T, _ = tokens.shape
98
+
99
+ if self.add_space_attn:
100
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
101
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
102
+
103
+ _, N, _, _ = tokens.shape
104
+
105
+ j = 0
106
+ for i in range(len(self.time_blocks)):
107
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
108
+
109
+ time_tokens = self.time_blocks[i](time_tokens)
110
+
111
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
112
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
113
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
114
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
115
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
116
+
117
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
118
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
119
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
120
+
121
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
122
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
123
+ j += 1
124
+
125
+ if self.add_space_attn:
126
+ tokens = tokens[:, : N - self.num_virtual_tracks]
127
+
128
+ tokens = tokens + init_tokens
129
+
130
+ # Apply output LayerNorm before final projection
131
+ tokens = self.output_norm(tokens)
132
+ flow = self.flow_head(tokens)
133
+
134
+ return flow, None
135
+
136
+
137
+ class CorrBlock:
138
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
139
+ """
140
+ Build a pyramid of feature maps from the input.
141
+
142
+ fmaps: Tensor (B, S, C, H, W)
143
+ num_levels: number of pyramid levels (each downsampled by factor 2)
144
+ radius: search radius for sampling correlation
145
+ multiple_track_feats: if True, split the target features per pyramid level
146
+ padding_mode: passed to grid_sample / bilinear_sampler
147
+ """
148
+ B, S, C, H, W = fmaps.shape
149
+ self.S, self.C, self.H, self.W = S, C, H, W
150
+ self.num_levels = num_levels
151
+ self.radius = radius
152
+ self.padding_mode = padding_mode
153
+ self.multiple_track_feats = multiple_track_feats
154
+
155
+ # Build pyramid: each level is half the spatial resolution of the previous
156
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
157
+ current_fmaps = fmaps
158
+ for i in range(num_levels - 1):
159
+ B, S, C, H, W = current_fmaps.shape
160
+ # Merge batch & sequence dimensions
161
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
162
+ # Avg pool down by factor 2
163
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
164
+ _, _, H_new, W_new = current_fmaps.shape
165
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
166
+ self.fmaps_pyramid.append(current_fmaps)
167
+
168
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
169
+ # This grid is added to the (scaled) coordinate centroids.
170
+ r = self.radius
171
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
172
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
173
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
174
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
175
+
176
+ def corr_sample(self, targets, coords):
177
+ """
178
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
179
+ volume, sample it immediately, then discard it. This saves GPU memory.
180
+
181
+ Args:
182
+ targets: Tensor (B, S, N, C) — features for the current targets.
183
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
184
+
185
+ Returns:
186
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
187
+ """
188
+ B, S, N, C = targets.shape
189
+
190
+ # If you have multiple track features, split them per level.
191
+ if self.multiple_track_feats:
192
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
193
+
194
+ out_pyramid = []
195
+ for i, fmaps in enumerate(self.fmaps_pyramid):
196
+ # Get current spatial resolution H, W for this pyramid level.
197
+ B, S, C, H, W = fmaps.shape
198
+ # Reshape feature maps for correlation computation:
199
+ # fmap2s: (B, S, C, H*W)
200
+ fmap2s = fmaps.view(B, S, C, H * W)
201
+ # Choose appropriate target features.
202
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
203
+
204
+ # Compute correlation directly
205
+ corrs = compute_corr_level(fmap1, fmap2s, C)
206
+ corrs = corrs.view(B, S, N, H, W)
207
+
208
+ # Prepare sampling grid:
209
+ # Scale down the coordinates for the current level.
210
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
211
+ # Make sure our precomputed delta grid is on the same device/dtype.
212
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
213
+ # Now the grid for grid_sample is:
214
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
215
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
216
+
217
+ # Sample from the correlation volume using bilinear interpolation.
218
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
219
+ corrs_sampled = bilinear_sampler(
220
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
221
+ )
222
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
223
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
224
+ out_pyramid.append(corrs_sampled)
225
+
226
+ # Concatenate all levels along the last dimension.
227
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
228
+ return out
229
+
230
+
231
+ def compute_corr_level(fmap1, fmap2s, C):
232
+ # fmap1: (B, S, N, C)
233
+ # fmap2s: (B, S, C, H*W)
234
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
235
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
236
+ return corrs / math.sqrt(C)
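A minimal sketch of the correlation computed by `compute_corr_level` above: each track's target feature is dot-multiplied against every spatial position of the feature map and normalized by sqrt(C), giving one response map per track (small random tensors, purely illustrative).

```python
import math
import torch

B, S, N, C, H, W = 1, 2, 3, 8, 5, 5
fmap1 = torch.randn(B, S, N, C)          # one target feature per track
fmap2s = torch.randn(B, S, C, H * W)     # flattened feature map per frame

corrs = torch.matmul(fmap1, fmap2s) / math.sqrt(C)   # (B, S, N, H*W)
corr_maps = corrs.view(B, S, N, H, W)                # one H x W response map per track
print(corr_maps.shape)                               # torch.Size([1, 2, 3, 5, 5])
```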
capvector-pi05/src/vggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
49
+ )
50
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
51
+ self.relu = nn.ReLU(inplace=True)
52
+
53
+ num_groups = planes // 8
54
+
55
+ if norm_fn == "group":
56
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
57
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
58
+ if not stride == 1:
59
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
60
+
61
+ elif norm_fn == "batch":
62
+ self.norm1 = nn.BatchNorm2d(planes)
63
+ self.norm2 = nn.BatchNorm2d(planes)
64
+ if not stride == 1:
65
+ self.norm3 = nn.BatchNorm2d(planes)
66
+
67
+ elif norm_fn == "instance":
68
+ self.norm1 = nn.InstanceNorm2d(planes)
69
+ self.norm2 = nn.InstanceNorm2d(planes)
70
+ if not stride == 1:
71
+ self.norm3 = nn.InstanceNorm2d(planes)
72
+
73
+ elif norm_fn == "none":
74
+ self.norm1 = nn.Sequential()
75
+ self.norm2 = nn.Sequential()
76
+ if not stride == 1:
77
+ self.norm3 = nn.Sequential()
78
+ else:
79
+ raise NotImplementedError
80
+
81
+ if stride == 1:
82
+ self.downsample = None
83
+ else:
84
+ self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
85
+
86
+ def forward(self, x):
87
+ y = x
88
+ y = self.relu(self.norm1(self.conv1(y)))
89
+ y = self.relu(self.norm2(self.conv2(y)))
90
+
91
+ if self.downsample is not None:
92
+ x = self.downsample(x)
93
+
94
+ return self.relu(x + y)
95
+
96
+
97
+ class Mlp(nn.Module):
98
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
99
+
100
+ def __init__(
101
+ self,
102
+ in_features,
103
+ hidden_features=None,
104
+ out_features=None,
105
+ act_layer=nn.GELU,
106
+ norm_layer=None,
107
+ bias=True,
108
+ drop=0.0,
109
+ use_conv=False,
110
+ ):
111
+ super().__init__()
112
+ out_features = out_features or in_features
113
+ hidden_features = hidden_features or in_features
114
+ bias = to_2tuple(bias)
115
+ drop_probs = to_2tuple(drop)
116
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
117
+
118
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
119
+ self.act = act_layer()
120
+ self.drop1 = nn.Dropout(drop_probs[0])
121
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
122
+ self.drop2 = nn.Dropout(drop_probs[1])
123
+
124
+ def forward(self, x):
125
+ x = self.fc1(x)
126
+ x = self.act(x)
127
+ x = self.drop1(x)
128
+ x = self.fc2(x)
129
+ x = self.drop2(x)
130
+ return x
131
+
132
+
133
+ class AttnBlock(nn.Module):
134
+ def __init__(
135
+ self,
136
+ hidden_size,
137
+ num_heads,
138
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
139
+ mlp_ratio=4.0,
140
+ **block_kwargs,
141
+ ):
142
+ """
143
+ Self attention block
144
+ """
145
+ super().__init__()
146
+
147
+ self.norm1 = nn.LayerNorm(hidden_size)
148
+ self.norm2 = nn.LayerNorm(hidden_size)
149
+
150
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
151
+
152
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
153
+
154
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
155
+
156
+ def forward(self, x, mask=None):
157
+ # Prepare the mask for PyTorch's attention (it expects a different format)
158
+ # attn_mask = mask if mask is not None else None
159
+ # Normalize before attention
160
+ x = self.norm1(x)
161
+
162
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
163
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
164
+
165
+ attn_output, _ = self.attn(x, x, x)
166
+
167
+ # Add & Norm
168
+ x = x + attn_output
169
+ x = x + self.mlp(self.norm2(x))
170
+ return x
171
+
172
+
173
+ class CrossAttnBlock(nn.Module):
174
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
175
+ """
176
+ Cross attention block
177
+ """
178
+ super().__init__()
179
+
180
+ self.norm1 = nn.LayerNorm(hidden_size)
181
+ self.norm_context = nn.LayerNorm(hidden_size)
182
+ self.norm2 = nn.LayerNorm(hidden_size)
183
+
184
+ self.cross_attn = nn.MultiheadAttention(
185
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
186
+ )
187
+
188
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
189
+
190
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
191
+
192
+ def forward(self, x, context, mask=None):
193
+ # Normalize inputs
194
+ x = self.norm1(x)
195
+ context = self.norm_context(context)
196
+
197
+ # Apply cross attention
198
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
199
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
200
+
201
+ # Add & Norm
202
+ x = x + attn_output
203
+ x = x + self.mlp(self.norm2(x))
204
+ return x
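A stripped-down sketch of the AttnBlock forward pass above, written directly against nn.MultiheadAttention with batch_first=True (dimensions are arbitrary and chosen only for illustration):

```python
import torch
import torch.nn as nn

hidden, heads, tokens = 64, 4, 10
norm1, norm2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=heads, batch_first=True)
mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))

x = torch.randn(2, tokens, hidden)
x = norm1(x)                     # normalize before attention, as in AttnBlock
attn_out, _ = attn(x, x, x)      # self-attention; attention weights are also returned
x = x + attn_out                 # residual around attention
x = x + mlp(norm2(x))            # residual around the MLP
print(x.shape)                   # torch.Size([2, 10, 64])
```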
capvector-pi05/src/vggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
19
+ """
20
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
21
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
22
+ Args:
23
+ - embed_dim: The embedding dimension.
24
+ - grid_size: The grid size.
25
+ Returns:
26
+ - pos_embed: The generated 2D positional embedding.
27
+ """
28
+ if isinstance(grid_size, tuple):
29
+ grid_size_h, grid_size_w = grid_size
30
+ else:
31
+ grid_size_h = grid_size_w = grid_size
32
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
33
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
34
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = torch.stack(grid, dim=0)
36
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
37
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
38
+ if return_grid:
39
+ return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
40
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
41
+
42
+
43
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
46
+
47
+ Args:
48
+ - embed_dim: The embedding dimension.
49
+ - grid: The grid to generate the embedding from.
50
+
51
+ Returns:
52
+ - emb: The generated 2D positional embedding.
53
+ """
54
+ assert embed_dim % 2 == 0
55
+
56
+ # use half of dimensions to encode grid_h
57
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
58
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
59
+
60
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
61
+ return emb
62
+
63
+
64
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
65
+ """
66
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
67
+
68
+ Args:
69
+ - embed_dim: The embedding dimension.
70
+ - pos: The position to generate the embedding from.
71
+
72
+ Returns:
73
+ - emb: The generated 1D positional embedding.
74
+ """
75
+ assert embed_dim % 2 == 0
76
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
77
+ omega /= embed_dim / 2.0
78
+ omega = 1.0 / 10000**omega # (D/2,)
79
+
80
+ pos = pos.reshape(-1) # (M,)
81
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
82
+
83
+ emb_sin = torch.sin(out) # (M, D/2)
84
+ emb_cos = torch.cos(out) # (M, D/2)
85
+
86
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
87
+ return emb[None].float()
88
+
89
+
90
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
91
+ """
92
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
93
+
94
+ Args:
95
+ - xy: The coordinates to generate the embedding from.
96
+ - C: The size of the embedding.
97
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
98
+
99
+ Returns:
100
+ - pe: The generated 2D positional embedding.
101
+ """
102
+ B, N, D = xy.shape
103
+ assert D == 2
104
+
105
+ x = xy[:, :, 0:1]
106
+ y = xy[:, :, 1:2]
107
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
108
+
109
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
110
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
111
+
112
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
113
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
114
+
115
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
116
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
117
+
118
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
119
+ if cat_coords:
120
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
121
+ return pe
122
+
123
+
124
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
125
+ r"""Sample a tensor using bilinear interpolation
126
+
127
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
128
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
129
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
130
+ convention.
131
+
132
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
133
+ :math:`B` is the batch size, :math:`C` is the number of channels,
134
+ :math:`H` is the height of the image, and :math:`W` is the width of the
135
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
136
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
137
+
138
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
139
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
140
+ that in this case the order of the components is slightly different
141
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
142
+
143
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
144
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
145
+ left-most image pixel :math:`W-1` to the center of the right-most
146
+ pixel.
147
+
148
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
149
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
150
+ the left-most pixel :math:`W` to the right edge of the right-most
151
+ pixel.
152
+
153
+ Similar conventions apply to the :math:`y` for the range
154
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
155
+ :math:`[0,T-1]` and :math:`[0,T]`.
156
+
157
+ Args:
158
+ input (Tensor): batch of input images.
159
+ coords (Tensor): batch of coordinates.
160
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
161
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
162
+
163
+ Returns:
164
+ Tensor: sampled points.
165
+ """
166
+ coords = coords.detach().clone()
167
+ ############################################################
168
+ # IMPORTANT:
169
+ coords = coords.to(input.device).to(input.dtype)
170
+ ############################################################
171
+
172
+ sizes = input.shape[2:]
173
+
174
+ assert len(sizes) in [2, 3]
175
+
176
+ if len(sizes) == 3:
177
+ # t x y -> x y t to match dimensions T H W in grid_sample
178
+ coords = coords[..., [1, 2, 0]]
179
+
180
+ if align_corners:
181
+ scale = torch.tensor(
182
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
183
+ )
184
+ else:
185
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
186
+
187
+ coords.mul_(scale) # coords = coords * scale
188
+ coords.sub_(1) # coords = coords - 1
189
+
190
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
191
+
192
+
193
+ def sample_features4d(input, coords):
194
+ r"""Sample spatial features
195
+
196
+ `sample_features4d(input, coords)` samples the spatial features
197
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
198
+
199
+ The field is sampled at coordinates :attr:`coords` using bilinear
200
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
201
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
202
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
203
+
204
+ The output tensor has one feature per point, and has shape :math:`(B,
205
+ R, C)`.
206
+
207
+ Args:
208
+ input (Tensor): spatial features.
209
+ coords (Tensor): points.
210
+
211
+ Returns:
212
+ Tensor: sampled features.
213
+ """
214
+
215
+ B, _, _, _ = input.shape
216
+
217
+ # B R 2 -> B R 1 2
218
+ coords = coords.unsqueeze(2)
219
+
220
+ # B C R 1
221
+ feats = bilinear_sampler(input, coords)
222
+
223
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
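A minimal sketch of the coordinate convention implemented by `bilinear_sampler` above with align_corners=True: pixel coordinates in [0, W-1] x [0, H-1] are mapped to grid_sample's [-1, 1] range via x * 2/(W-1) - 1 (tiny example, values chosen so the sample lands exactly on a pixel center):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 1, 4, 5
img = torch.arange(H * W, dtype=torch.float32).view(B, C, H, W)

coords = torch.tensor([[[[2.0, 1.0]]]])                       # (B, 1, 1, 2): x=2, y=1
grid = coords * torch.tensor([2 / (W - 1), 2 / (H - 1)]) - 1  # normalized coords
sample = F.grid_sample(img, grid, align_corners=True)
print(sample.item(), img[0, 0, 1, 2].item())                  # both 7.0
```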
capvector-pi05/src/vggt/heads/utils.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ from typing import List, Dict, Tuple, Union
11
+ from einops import rearrange
12
+
13
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
14
+ """
15
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
16
+
17
+ Args:
18
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
19
+ embed_dim: Output channel dimension for embeddings
20
+
21
+ Returns:
22
+ Tensor of shape (H, W, embed_dim) with positional embeddings
23
+ """
24
+ H, W, grid_dim = pos_grid.shape
25
+ assert grid_dim == 2
26
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
27
+
28
+ # Process x and y coordinates separately
29
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
30
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
31
+
32
+ # Combine and reshape
33
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
34
+
35
+ return emb.view(H, W, embed_dim) # [H, W, D]
36
+
37
+
38
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
39
+ """
40
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
41
+
42
+ Args:
43
+ - embed_dim: The embedding dimension.
44
+ - pos: The position to generate the embedding from.
45
+
46
+ Returns:
47
+ - emb: The generated 1D positional embedding.
48
+ """
49
+ assert embed_dim % 2 == 0
50
+ device = pos.device
51
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
52
+ omega /= embed_dim / 2.0
53
+ omega = 1.0 / omega_0**omega # (D/2,)
54
+
55
+ pos = pos.reshape(-1) # (M,)
56
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
57
+
58
+ emb_sin = torch.sin(out) # (M, D/2)
59
+ emb_cos = torch.cos(out) # (M, D/2)
60
+
61
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
62
+ return emb.float()
63
+
64
+
65
+ # Inspired by https://github.com/microsoft/moge
66
+
67
+
68
+ def create_uv_grid(
69
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
70
+ ) -> torch.Tensor:
71
+ """
72
+ Create a normalized UV grid of shape (width, height, 2).
73
+
74
+ The grid spans horizontally and vertically according to an aspect ratio,
75
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
76
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
77
+
78
+ Args:
79
+ width (int): Number of points horizontally.
80
+ height (int): Number of points vertically.
81
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
82
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
83
+ device (torch.device, optional): Device on which the tensor is created.
84
+
85
+ Returns:
86
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
87
+ """
88
+ # Derive aspect ratio if not explicitly provided
89
+ if aspect_ratio is None:
90
+ aspect_ratio = float(width) / float(height)
91
+
92
+ # Compute normalized spans for X and Y
93
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
94
+ span_x = aspect_ratio / diag_factor
95
+ span_y = 1.0 / diag_factor
96
+
97
+ # Establish the linspace boundaries
98
+ left_x = -span_x * (width - 1) / width
99
+ right_x = span_x * (width - 1) / width
100
+ top_y = -span_y * (height - 1) / height
101
+ bottom_y = span_y * (height - 1) / height
102
+
103
+ # Generate 1D coordinates
104
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
105
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
106
+
107
+ # Create 2D meshgrid (width x height) and stack into UV
108
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
109
+ uv_grid = torch.stack((uu, vv), dim=-1)
110
+
111
+ return uv_grid
112
+
113
+
114
+ def _interpolate(
115
+ x: torch.Tensor,
116
+ size: Tuple[int, int] = None,
117
+ scale_factor: float = None,
118
+ mode: str = "bilinear",
119
+ align_corners: bool = True,
120
+ ) -> torch.Tensor:
121
+ """
122
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
123
+ """
124
+ if size is None:
125
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
126
+
127
+ INT_MAX = 1610612736
128
+
129
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
130
+
131
+ if input_elements > INT_MAX:
132
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
133
+ interpolated_chunks = [
134
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
135
+ ]
136
+ x = torch.cat(interpolated_chunks, dim=0)
137
+ return x.contiguous()
138
+ else:
139
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
140
+
141
+
142
+ def _apply_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
143
+ """
144
+ Apply positional embedding to tensor x.
145
+ """
146
+ patch_w = x.shape[-1]
147
+ patch_h = x.shape[-2]
148
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
149
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
150
+ pos_embed = pos_embed * ratio
151
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
152
+ return x + pos_embed
153
+
154
+
155
+ def interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe):
156
+ (patch_h, patch_w) = patch_hw
157
+ (img_h, img_w) = img_hw
158
+ bs, N, S, D = hidden.shape
159
+ re_sample_ratio = 1 / np.sqrt(N * S / reference.shape[1])
160
+
161
+ _hidden = hidden.permute(0, 1, 3, 2)
162
+ _hidden = _hidden.reshape(bs*N, D, patch_h, patch_w)
163
+ if use_vggt_pe:
164
+ _hidden = _apply_pos_embed(_hidden, img_w, img_h)
165
+ hidden_pooling = _interpolate(
166
+ _hidden, scale_factor=re_sample_ratio, mode=pooling_func, align_corners=True
167
+ )
168
+ hidden_pooling = hidden_pooling.reshape(bs, N, D, -1).permute(0, 1, 3, 2).reshape(bs, -1, D)
169
+ return hidden_pooling
170
+
171
+
172
+ def custom_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe):
173
+ if pooling_func in ['bilinear']:
174
+ return interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe)
175
+ else:
176
+ raise NotImplementedError(f"Pooling function {pooling_func} is not implemented.")
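A minimal sketch of the normalization used in `create_uv_grid` above: the half-spans are chosen so that the (span_x, span_y) corner lies on the unit circle, i.e. the half-diagonal of the image plane is normalized to 1 (numbers are an arbitrary example resolution):

```python
width, height = 640, 480
aspect_ratio = width / height
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
span_x = aspect_ratio / diag_factor
span_y = 1.0 / diag_factor
print(span_x ** 2 + span_y ** 2)  # 1.0
```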
capvector-pi05/src/vggt/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
capvector-pi05/src/vggt/layers/attention.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ XFORMERS_AVAILABLE = False
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def forward(self, x: Tensor, pos=None) -> Tensor:
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q), self.k_norm(k)
55
+
56
+ if self.rope is not None:
57
+ q = self.rope(q, pos)
58
+ k = self.rope(k, pos)
59
+
60
+ if self.fused_attn:
61
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
62
+ else:
63
+ q = q * self.scale
64
+ attn = q @ k.transpose(-2, -1)
65
+ attn = attn.softmax(dim=-1)
66
+ attn = self.attn_drop(attn)
67
+ x = attn @ v
68
+
69
+ x = x.transpose(1, 2).reshape(B, N, C)
70
+ x = self.proj(x)
71
+ x = self.proj_drop(x)
72
+ return x
73
+
74
+
75
+ class MemEffAttention(Attention):
76
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
77
+ assert pos is None
78
+ if not XFORMERS_AVAILABLE:
79
+ if attn_bias is not None:
80
+ raise AssertionError("xFormers is required for using nested tensors")
81
+ return super().forward(x)
82
+
83
+ B, N, C = x.shape
84
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
85
+
86
+ q, k, v = unbind(qkv, 2)
87
+
88
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
89
+ x = x.reshape([B, N, C])
90
+
91
+ x = self.proj(x)
92
+ x = self.proj_drop(x)
93
+ return x
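A minimal sketch checking that the two branches of `Attention.forward` above agree: `F.scaled_dot_product_attention` versus the explicit softmax(q k^T / sqrt(d)) v (random tensors, no dropout):

```python
import torch
import torch.nn.functional as F

B, heads, N, d = 2, 4, 10, 16
q, k, v = (torch.randn(B, heads, N, d) for _ in range(3))

fused = F.scaled_dot_product_attention(q, k, v)
manual = ((q * d ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1) @ v
print(torch.allclose(fused, manual, atol=1e-5))  # True
```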
capvector-pi05/src/vggt/layers/block.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
71
+ )
72
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
+
75
+ self.sample_drop_ratio = drop_path
76
+
77
+ def forward(self, x: Tensor, pos=None) -> Tensor:
78
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
79
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
80
+
81
+ def ffn_residual_func(x: Tensor) -> Tensor:
82
+ return self.ls2(self.mlp(self.norm2(x)))
83
+
84
+ if self.training and self.sample_drop_ratio > 0.1:
85
+ # the overhead is compensated only for a drop path rate larger than 0.1
86
+ x = drop_add_residual_stochastic_depth(
87
+ x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
88
+ )
89
+ x = drop_add_residual_stochastic_depth(
90
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
91
+ )
92
+ elif self.training and self.sample_drop_ratio > 0.0:
93
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
94
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
95
+ else:
96
+ x = x + attn_residual_func(x, pos=pos)
97
+ x = x + ffn_residual_func(x)
98
+ return x
99
+
100
+
101
+ def drop_add_residual_stochastic_depth(
102
+ x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
103
+ ) -> Tensor:
104
+ # 1) extract subset using permutation
105
+ b, n, d = x.shape
106
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
107
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
108
+ x_subset = x[brange]
109
+
110
+ # 2) apply residual_func to get residual
111
+ if pos is not None:
112
+ # if necessary, apply rope to the subset
113
+ pos = pos[brange]
114
+ residual = residual_func(x_subset, pos=pos)
115
+ else:
116
+ residual = residual_func(x_subset)
117
+
118
+ x_flat = x.flatten(1)
119
+ residual = residual.flatten(1)
120
+
121
+ residual_scale_factor = b / sample_subset_size
122
+
123
+ # 3) add the residual
124
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
125
+ return x_plus_residual.view_as(x)
126
+
127
+
128
+ def get_branges_scales(x, sample_drop_ratio=0.0):
129
+ b, n, d = x.shape
130
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
131
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
132
+ residual_scale_factor = b / sample_subset_size
133
+ return brange, residual_scale_factor
134
+
135
+
136
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
137
+ if scaling_vector is None:
138
+ x_flat = x.flatten(1)
139
+ residual = residual.flatten(1)
140
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
141
+ else:
142
+ x_plus_residual = scaled_index_add(
143
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
144
+ )
145
+ return x_plus_residual
146
+
147
+
148
+ attn_bias_cache: Dict[Tuple, Any] = {}
149
+
150
+
151
+ def get_attn_bias_and_cat(x_list, branges=None):
152
+ """
153
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
154
+ """
155
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
156
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
157
+ if all_shapes not in attn_bias_cache.keys():
158
+ seqlens = []
159
+ for b, x in zip(batch_sizes, x_list):
160
+ for _ in range(b):
161
+ seqlens.append(x.shape[1])
162
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
163
+ attn_bias._batch_sizes = batch_sizes
164
+ attn_bias_cache[all_shapes] = attn_bias
165
+
166
+ if branges is not None:
167
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
168
+ else:
169
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
170
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
171
+
172
+ return attn_bias_cache[all_shapes], cat_tensors
173
+
174
+
175
+ def drop_add_residual_stochastic_depth_list(
176
+ x_list: List[Tensor],
177
+ residual_func: Callable[[Tensor, Any], Tensor],
178
+ sample_drop_ratio: float = 0.0,
179
+ scaling_vector=None,
180
+ ) -> Tensor:
181
+ # 1) generate random set of indices for dropping samples in the batch
182
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
183
+ branges = [s[0] for s in branges_scales]
184
+ residual_scale_factors = [s[1] for s in branges_scales]
185
+
186
+ # 2) get attention bias and index+concat the tensors
187
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
188
+
189
+ # 3) apply residual_func to get residual, and split the result
190
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
191
+
192
+ outputs = []
193
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
194
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
195
+ return outputs
196
+
197
+
198
+ class NestedTensorBlock(Block):
199
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
200
+ """
201
+ x_list contains a list of tensors to nest together and run
202
+ """
203
+ assert isinstance(self.attn, MemEffAttention)
204
+
205
+ if self.training and self.sample_drop_ratio > 0.0:
206
+
207
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
208
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
209
+
210
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
211
+ return self.mlp(self.norm2(x))
212
+
213
+ x_list = drop_add_residual_stochastic_depth_list(
214
+ x_list,
215
+ residual_func=attn_residual_func,
216
+ sample_drop_ratio=self.sample_drop_ratio,
217
+ scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
218
+ )
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=ffn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
224
+ )
225
+ return x_list
226
+ else:
227
+
228
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
229
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
230
+
231
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
232
+ return self.ls2(self.mlp(self.norm2(x)))
233
+
234
+ attn_bias, x = get_attn_bias_and_cat(x_list)
235
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
236
+ x = x + ffn_residual_func(x)
237
+ return attn_bias.split(x)
238
+
239
+ def forward(self, x_or_x_list):
240
+ if isinstance(x_or_x_list, Tensor):
241
+ return super().forward(x_or_x_list)
242
+ elif isinstance(x_or_x_list, list):
243
+ if not XFORMERS_AVAILABLE:
244
+ raise AssertionError("xFormers is required for using nested tensors")
245
+ return self.forward_nested(x_or_x_list)
246
+ else:
247
+ raise AssertionError
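A minimal sketch of `drop_add_residual_stochastic_depth` above: the residual branch is applied to a random subset of the batch, then added back with `index_add` scaled by b / subset_size so the expected update matches the full batch (the branch here is a stand-in lambda, not the real attention/MLP):

```python
import torch

b, n, d = 8, 4, 16
x = torch.randn(b, n, d)
residual_func = lambda t: 0.1 * t          # stand-in for the attention / MLP branch

drop_ratio = 0.5
subset = max(int(b * (1 - drop_ratio)), 1)
brange = torch.randperm(b)[:subset]        # samples that actually run the branch
residual = residual_func(x[brange])

out = torch.index_add(x.flatten(1), 0, brange, residual.flatten(1),
                      alpha=b / subset).view_as(x)
print(out.shape)                           # torch.Size([8, 4, 16])
```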
capvector-pi05/src/vggt/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
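A minimal sketch of the rescaling in `drop_path` above: each sample is kept with probability keep_prob and divided by keep_prob, so the expected output equals the input:

```python
import torch

torch.manual_seed(0)
x = torch.ones(10000, 1)
keep_prob = 0.8
mask = x.new_empty(x.shape[0], 1).bernoulli_(keep_prob).div_(keep_prob)
out = x * mask
print(out.mean().item())  # close to 1.0
```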
capvector-pi05/src/vggt/layers/layer_scale.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
17
+ super().__init__()
18
+ self.inplace = inplace
19
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
20
+
21
+ def forward(self, x: Tensor) -> Tensor:
22
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
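A minimal sketch of what LayerScale above does: a per-channel gain, initialized to a small constant so a residual branch contributes almost nothing at the start of training and the block begins close to the identity mapping (the gain is shown as a plain tensor; in the module it is an nn.Parameter):

```python
import torch

dim, init_values = 8, 1e-5
gamma = init_values * torch.ones(dim)      # learnable per-channel gain in the module
branch_out = torch.randn(2, 4, dim)        # (batch, tokens, channels)
scaled = branch_out * gamma                # broadcast over the channel dimension
print(scaled.abs().max().item())           # tiny right after initialization
```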
capvector-pi05/src/vggt/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
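A brief usage sketch for the Mlp class above (assumes the class is in scope, e.g. defined in the same module or imported from this file; sizes are arbitrary):

```python
import torch

mlp = Mlp(in_features=384, hidden_features=1536, drop=0.1)
tokens = torch.randn(2, 196, 384)   # (batch, tokens, channels)
out = mlp(tokens)                   # shape preserved: (2, 196, 384)
print(out.shape)
```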