| import concurrent.futures |
| import datetime |
| import getpass |
| import logging |
| import os |
| import pathlib |
| import re |
| import shutil |
| import stat |
| import time |
| import urllib.parse |
|
|
| import filelock |
| import fsspec |
| import fsspec.generic |
| import tqdm_loggable.auto as tqdm |
|
|
| |
# Name of the environment variable that overrides the default cache location.
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"


# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
def get_cache_dir() -> pathlib.Path:
    """Return the local cache directory as an absolute path, creating it if needed.

    The location is taken from the ``OPENPI_DATA_HOME`` environment variable when
    set; otherwise it defaults to ``~/.cache/openpi`` (or a per-user directory
    under ``/mnt/weka`` when that mount is present). The directory is created and
    opened up with world read/write/search permissions before being returned.
    """
    # Prefer the shared weka mount when it exists; fall back to the home cache.
    default_dir = (
        f"/mnt/weka/{getpass.getuser()}/.cache/openpi"
        if os.path.exists("/mnt/weka")
        else "~/.cache/openpi"
    )

    cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, default_dir)).expanduser().resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    _set_folder_permission(cache_dir)
    return cache_dir
|
|
|
|
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.

    It is safe to call this function concurrently from multiple processes.
    See `get_cache_dir` for more details on the cache directory.

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.
    """
    parsed = urllib.parse.urlparse(url)

    # Short-circuit local paths: no scheme means this is already on disk.
    if parsed.scheme == "":
        path = pathlib.Path(url)
        if not path.exists():
            raise FileNotFoundError(f"File not found at {url}")
        return path.resolve()

    cache_dir = get_cache_dir()

    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Cheap fast path checked outside the lock; re-validated below once the lock is held.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            return local_path

    try:
        lock_path = local_path.with_suffix(".lock")
        with filelock.FileLock(lock_path):
            _ensure_permissions(lock_path)

            # Re-check after acquiring the lock: another process may have finished
            # the download while we were waiting. Without this, we would re-download
            # and `shutil.move` below could move the scratch dir *into* the existing
            # directory instead of replacing it.
            if not invalidate_cache and local_path.exists():
                return local_path

            if invalidate_cache and local_path.exists():
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()

            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_suffix(".partial")
            # Remove any stale partial download left behind by a crashed run;
            # otherwise fs.get could nest the new download inside the old scratch dir.
            if scratch_path.exists():
                if scratch_path.is_dir():
                    shutil.rmtree(scratch_path)
                else:
                    scratch_path.unlink()
            _download_fsspec(url, scratch_path, **kwargs)

            # Move into place only after the download fully succeeded, so a partial
            # download is never visible at the final path.
            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)

    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path
|
|
|
|
def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
    """Download a file or directory from a remote filesystem to `local_path`.

    Runs the fsspec transfer in a worker thread while the main thread polls the
    bytes written so far to drive a progress bar.

    Args:
        url: Remote URL understood by fsspec.
        local_path: Destination path on the local filesystem.
        **kwargs: Additional arguments forwarded to `fsspec.core.url_to_fs`.

    Raises:
        Whatever exception the underlying fsspec transfer raises.
    """
    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
    info = fs.info(url)
    # Some filesystems report directories as zero-size entries whose name ends with "/".
    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
        total_size = fs.du(url)
    else:
        total_size = info["size"]
    with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
        # Context manager ensures the executor is shut down even on error
        # (the original leaked the executor).
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(fs.get, url, local_path, recursive=is_dir)
            while not future.done():
                # Best-effort progress: files may appear/vanish while the
                # transfer is in flight, so tolerate stat races.
                try:
                    current_size = sum(
                        f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file()
                    )
                except OSError:
                    current_size = pbar.n
                pbar.update(current_size - pbar.n)
                time.sleep(1)
            # Propagate any exception from the worker thread; the original
            # silently swallowed download failures here.
            future.result()
        pbar.update(total_size - pbar.n)
|
|
|
def _set_permission(path: pathlib.Path, target_permission: int) -> None:
    """Chmod `path` to `target_permission`, skipping when the bits are already set.

    chmod requires executable permission to be set, so we skip if the permission
    already matches the target.
    """
    current_mode = path.stat().st_mode
    already_set = (current_mode & target_permission) == target_permission
    if already_set:
        logger.debug(f"Skipping {path} because it already has correct permissions")
        return
    path.chmod(target_permission)
    logger.debug(f"Set {path} to {target_permission}")
|
|
|
|
def _set_folder_permission(folder_path: pathlib.Path) -> None:
    """Set folder permission to be read, write and searchable."""
    # Equivalent to 0o777: rwx for user, group, and other.
    all_rwx = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
    _set_permission(folder_path, all_rwx)
|
|
|
|
def _ensure_permissions(path: pathlib.Path) -> None:
    """Since we are sharing cache directory with containerized runtime as well as training script, we need to
    ensure that the cache directory has the correct permissions.

    Opens up every directory from the cache root down to `path`, then walks
    `path` recursively making all files world-read/writable and all
    sub-directories world-searchable.
    """

    def _open_dirs_from_cache_root(target: pathlib.Path) -> None:
        # Chmod each directory on the path from the cache root down to `target`
        # so other users can traverse into it.
        # NOTE(review): when `target` is a regular file, the final component is
        # also chmod'd with the folder bits — this matches the original behavior.
        cache_dir = get_cache_dir()
        current = cache_dir
        for part in target.relative_to(cache_dir).parts:
            current = current / part
            _set_folder_permission(current)

    def _set_file_permission(file_path: pathlib.Path) -> None:
        """Set all files to be read & writable, if it is a script, keep it as a script."""
        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
        # 0o100 = owner-executable: preserve executability for all users.
        if file_path.stat().st_mode & 0o100:
            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        else:
            _set_permission(file_path, file_rw)

    _open_dirs_from_cache_root(path)
    # os.walk yields nothing when `path` is a regular file, so the loop below
    # only applies to downloaded directories.
    # Renamed loop variables: the original `dir`/`file` shadowed builtins.
    for root, dir_names, file_names in os.walk(str(path)):
        root_path = pathlib.Path(root)
        for file_name in file_names:
            _set_file_permission(root_path / file_name)
        for dir_name in dir_names:
            _set_folder_permission(root_path / dir_name)
|
|
|
|
| def _get_mtime(year: int, month: int, day: int) -> float: |
| """Get the mtime of a given date at midnight UTC.""" |
| date = datetime.datetime(year, month, day, tzinfo=datetime.UTC) |
| return time.mktime(date.timetuple()) |
|
|
|
|
| |
| |
| |
# Maps a cache-relative path pattern to an expiry cutoff time: cached entries
# matching a pattern whose mtime is at or before the cutoff get re-downloaded
# (see `_should_invalidate_cache`). Patterns are tried in insertion order and
# the first match wins, so more specific patterns must come before the
# catch-all "checkpoints/" entry.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}
|
|
|
|
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Decide whether the cached entry at `local_path` has expired.

    Matches the entry's cache-relative path against `_INVALIDATE_CACHE_DIRS`
    (first match wins) and compares its mtime to the configured cutoff.
    Returns True when the entry should be invalidated.
    """
    assert local_path.exists(), f"File not found at {local_path}"

    relative_path = str(local_path.relative_to(cache_dir))
    expire_time = next(
        (cutoff for pattern, cutoff in _INVALIDATE_CACHE_DIRS.items() if pattern.match(relative_path)),
        None,
    )
    if expire_time is None:
        # No pattern governs this path; the cache entry never expires.
        return False
    # Stale if the entry predates (or equals) the expiry cutoff.
    return local_path.stat().st_mtime <= expire_time
|
|