| """
|
| load.py
|
|
|
| Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
|
| IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
|
| """
|
|
|
| import json
|
| import os
|
| from pathlib import Path
|
| from typing import List, Optional, Union
|
|
|
| from huggingface_hub import HfFileSystem, hf_hub_download
|
|
|
| from prismatic.conf import ModelConfig
|
| from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
|
| from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY
|
| from prismatic.models.vlas import OpenVLA
|
| from prismatic.models.vlms import PrismaticVLM
|
| from prismatic.overwatch import initialize_overwatch
|
| from prismatic.vla.action_tokenizer import ActionTokenizer
|
|
|
|
|
# Initialize Overwatch =>> module-level logging/progress utility, keyed to this module's name
overwatch = initialize_overwatch(__name__)


# === HF Hub Repository IDs ===
# Pretrained Prismatic VLM checkpoints (used by `load`)
HF_HUB_REPO = "TRI-ML/prismatic-vlms"

# Pretrained OpenVLA checkpoints (used by `load_vla`)
VLA_HF_HUB_REPO = "openvla/openvla-dev"
|
|
|
|
|
|
|
def available_models() -> List[str]:
    """Return the canonical IDs of all models tracked in the model registry."""
    return [model_id for model_id in MODEL_REGISTRY]
|
|
|
|
|
def available_model_names() -> List[str]:
    """Return every registered model name/alias (keys of the global registry).

    NOTE: the previous implementation returned `GLOBAL_REGISTRY.items()` — a list of
    `(name, metadata)` tuples — contradicting both the function name and the declared
    `List[str]` return type (and the sibling `available_models`, which returns keys).
    """
    return list(GLOBAL_REGISTRY.keys())
|
|
|
|
|
def get_model_description(model_id_or_name: str) -> str:
    """Look up, pretty-print, and return the short description for a registered model.

    :param model_id_or_name: Canonical model ID or any registered alias.
    :return: The model's description string from the global registry.
    :raises ValueError: If `model_id_or_name` is not present in `GLOBAL_REGISTRY`.
    """
    if model_id_or_name not in GLOBAL_REGISTRY:
        # Fix: the original message opened a backtick around the repr but never closed it.
        raise ValueError(f"Couldn't find `{model_id_or_name = }`; check `prismatic.available_model_names()`")

    # Print the description as an indented JSON block for readability, then return it.
    print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2))

    return description
|
|
|
|
|
|
|
def load(
    model_id_or_path: Union[str, Path],
    hf_token: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    load_for_training: bool = False,
) -> PrismaticVLM:
    """Load a pretrained PrismaticVLM from either local disk or the HuggingFace Hub.

    :param model_id_or_path: Path to a local run directory (containing `config.json` and
        `checkpoints/latest-checkpoint.pt`) *or* a registered model ID/name to download from the HF Hub.
    :param hf_token: Optional HF auth token (needed for gated LLM backbones, e.g., Llama-2).
    :param cache_dir: Optional cache directory forwarded to `hf_hub_download`.
    :param load_for_training: If True, load the LLM backbone in training mode and leave VLM weights unfrozen.
    :return: The instantiated `PrismaticVLM` with weights restored from the checkpoint.
    :raises ValueError: If `model_id_or_path` is neither a local directory nor a registered model ID/name.
    """
    if os.path.isdir(model_id_or_path):
        overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`")

        # Resolve paths for `config.json` and the "latest" pretrained checkpoint
        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
    else:
        if model_id_or_path not in GLOBAL_REGISTRY:
            # Fix: the original message opened a backtick around the repr but never closed it.
            raise ValueError(f"Couldn't find `{model_id_or_path = }`; check `prismatic.available_model_names()`")

        # Fix: close the backtick around the model ID in the log message.
        overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])}` from HF Hub")
        # Download on local rank zero first so other ranks hit the already-populated cache.
        with overwatch.local_zero_first():
            config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
            checkpoint_pt = hf_hub_download(
                repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
            )

    # Load the model configuration (stored under the "model" key of `config.json`)
    with open(config_json, "r") as f:
        model_cfg = json.load(f)["model"]

    # Log the resolved configuration before doing any heavy lifting
    overwatch.info(
        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
        f"             Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
        f"             LLM Backbone    =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
        f"             Arch Specifier  =>> [bold]{model_cfg['arch_specifier']}[/]\n"
        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
    )

    # Instantiate the vision backbone (+ its image transform, unused here but part of the factory's return)
    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]")
    vision_backbone, image_transform = get_vision_backbone_and_transform(
        model_cfg["vision_backbone_id"],
        model_cfg["image_resize_strategy"],
    )

    # Instantiate the LLM backbone; default `llm_max_length` to 2048 for older configs that predate the field
    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers")
    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
        model_cfg["llm_backbone_id"],
        llm_max_length=model_cfg.get("llm_max_length", 2048),
        hf_token=hf_token,
        inference_mode=not load_for_training,
    )

    # Assemble the VLM and restore weights from the checkpoint (frozen unless loading for training)
    overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint")
    vlm = PrismaticVLM.from_pretrained(
        checkpoint_pt,
        model_cfg["model_id"],
        vision_backbone,
        llm_backbone,
        arch_specifier=model_cfg["arch_specifier"],
        freeze_weights=not load_for_training,
    )

    return vlm
|
|
|
|
|
|
|
def load_vla(
    model_id_or_path: Union[str, Path],
    hf_token: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    load_for_training: bool = False,
    step_to_load: Optional[int] = None,
    model_type: str = "pretrained",
) -> OpenVLA:
    """Load a pretrained OpenVLA from either local disk or the HuggingFace Hub.

    :param model_id_or_path: Path to a local `.pt` checkpoint (inside a `checkpoints/` directory whose parent
        holds `config.json` and `dataset_statistics.json`) *or* a model ID to resolve on the HF Hub.
    :param hf_token: Optional HF auth token (needed for gated LLM backbones).
    :param cache_dir: Optional cache directory forwarded to `hf_hub_download`.
    :param load_for_training: If True, load the LLM backbone in training mode and leave weights unfrozen.
    :param step_to_load: Optional training step; selects the unique `step-{step:06d}*.pt` checkpoint on the Hub.
    :param model_type: Subdirectory of the VLA Hub repo to search (e.g., "pretrained").
    :return: The instantiated `OpenVLA` with weights, normalization stats, and action tokenizer attached.
    :raises ValueError: If the Hub path does not exist, or no (unique) matching checkpoint is found.
    """
    # === Local Path Case ===
    if os.path.isfile(model_id_or_path):
        overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")

        # Expected layout: <run_dir>/checkpoints/<ckpt>.pt
        assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
        run_dir = checkpoint_pt.parents[1]

        # Sibling metadata files live in the run directory
        config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"

    # === HF Hub Case ===
    else:
        # Hub layout: <repo>/<model_type>/<model_id>/...
        overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
        if not (tmpfs := HfFileSystem()).exists(hf_path):
            raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")

        # Checkpoints are named `step-{step:06d}*.pt`; an explicit `step_to_load` must match exactly one file
        step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None
        valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
        if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
            # Fix: the original message opened a backtick around the path but never closed it.
            raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/`")

        # Take the last glob match; fsspec returns sorted paths, so with zero-padded step names this is
        # the latest checkpoint when `step_to_load` is None (and the single match otherwise).
        target_ckpt = Path(valid_ckpts[-1]).name

        overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
        # Download on local rank zero first so other ranks hit the already-populated cache.
        with overwatch.local_zero_first():
            relpath = Path(model_type) / model_id_or_path
            config_json = hf_hub_download(
                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
            )
            dataset_statistics_json = hf_hub_download(
                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
            )
            checkpoint_pt = hf_hub_download(
                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir
            )

    # Resolve the base VLM config class from the VLA config's `base_vlm` field
    with open(config_json, "r") as f:
        vla_cfg = json.load(f)["vla"]
        model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()

    # Load dataset statistics used to de-normalize predicted actions at inference time
    with open(dataset_statistics_json, "r") as f:
        norm_stats = json.load(f)

    # Log the resolved configuration before doing any heavy lifting
    overwatch.info(
        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
        f"             Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n"
        f"             LLM Backbone    =>> [bold]{model_cfg.llm_backbone_id}[/]\n"
        f"             Arch Specifier  =>> [bold]{model_cfg.arch_specifier}[/]\n"
        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
    )

    # Instantiate the vision backbone (+ its image transform, unused here but part of the factory's return)
    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]")
    vision_backbone, image_transform = get_vision_backbone_and_transform(
        model_cfg.vision_backbone_id,
        model_cfg.image_resize_strategy,
    )

    # Instantiate the LLM backbone (+ tokenizer)
    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers")
    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
        model_cfg.llm_backbone_id,
        llm_max_length=model_cfg.llm_max_length,
        hf_token=hf_token,
        inference_mode=not load_for_training,
    )

    # Wrap the LLM's tokenizer so (discretized) actions map onto the token vocabulary
    action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())

    # Assemble the VLA and restore weights from the checkpoint (frozen unless loading for training)
    overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint")
    vla = OpenVLA.from_pretrained(
        checkpoint_pt,
        model_cfg.model_id,
        vision_backbone,
        llm_backbone,
        arch_specifier=model_cfg.arch_specifier,
        freeze_weights=not load_for_training,
        norm_stats=norm_stats,
        action_tokenizer=action_tokenizer,
    )

    return vla
|
|
|